diff options
Diffstat (limited to 'pkg/state/decode.go')
-rw-r--r-- | pkg/state/decode.go | 918 |
1 files changed, 517 insertions, 401 deletions
diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 590c241a3..c9971cdf6 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -17,28 +17,49 @@ package state import ( "bytes" "context" - "encoding/binary" - "errors" "fmt" - "io" + "math" "reflect" - "sort" - "github.com/golang/protobuf/proto" - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" + "gvisor.dev/gvisor/pkg/state/wire" ) -// objectState represents an object that may be in the process of being +// internalCallback is a interface called on object completion. +// +// There are two implementations: objectDecodeState & userCallback. +type internalCallback interface { + // source returns the dependent object. May be nil. + source() *objectDecodeState + + // callbackRun executes the callback. + callbackRun() +} + +// userCallback is an implementation of internalCallback. +type userCallback func() + +// source implements internalCallback.source. +func (userCallback) source() *objectDecodeState { + return nil +} + +// callbackRun implements internalCallback.callbackRun. +func (uc userCallback) callbackRun() { + uc() +} + +// objectDecodeState represents an object that may be in the process of being // decoded. Specifically, it represents either a decoded object, or an an // interest in a future object that will be decoded. When that interest is // registered (via register), the storage for the object will be created, but // it will not be decoded until the object is encountered in the stream. -type objectState struct { +type objectDecodeState struct { // id is the id for this object. - // - // If this field is zero, then this is an anonymous (unregistered, - // non-reference primitive) object. This is immutable. - id uint64 + id objectID + + // typ is the id for this typeID. This may be zero if this is not a + // type-registered structure. + typ typeID // obj is the object. This may or may not be valid yet, depending on // whether complete returns true. However, regardless of whether the @@ -57,69 +78,52 @@ type objectState struct { // blockedBy is the number of dependencies this object has. blockedBy int - // blocking is a list of the objects blocked by this one. - blocking []*objectState + // callbacksInline is inline storage for callbacks. + callbacksInline [2]internalCallback // callbacks is a set of callbacks to execute on load. - callbacks []func() - - // path is the decoding path to the object. - path recoverable -} - -// complete indicates the object is complete. -func (os *objectState) complete() bool { - return os.blockedBy == 0 && len(os.callbacks) == 0 -} - -// checkComplete checks for completion. If the object is complete, pending -// callbacks will be executed and checkComplete will be called on downstream -// objects (those depending on this one). -func (os *objectState) checkComplete(stats *Stats) { - if os.blockedBy > 0 { - return - } - stats.Start(os.obj) + callbacks []internalCallback - // Fire all callbacks. - for _, fn := range os.callbacks { - fn() - } - os.callbacks = nil - - // Clear all blocked objects. - for _, other := range os.blocking { - other.blockedBy-- - other.checkComplete(stats) - } - os.blocking = nil - stats.Done() + completeEntry } -// waitFor queues a dependency on the given object. -func (os *objectState) waitFor(other *objectState, callback func()) { - os.blockedBy++ - other.blocking = append(other.blocking, os) - if callback != nil { - other.callbacks = append(other.callbacks, callback) +// addCallback adds a callback to the objectDecodeState. +func (ods *objectDecodeState) addCallback(ic internalCallback) { + if ods.callbacks == nil { + ods.callbacks = ods.callbacksInline[:0] } + ods.callbacks = append(ods.callbacks, ic) } // findCycleFor returns when the given object is found in the blocking set. -func (os *objectState) findCycleFor(target *objectState) []*objectState { - for _, other := range os.blocking { - if other == target { - return []*objectState{target} +func (ods *objectDecodeState) findCycleFor(target *objectDecodeState) []*objectDecodeState { + for _, ic := range ods.callbacks { + other := ic.source() + if other != nil && other == target { + return []*objectDecodeState{target} } else if childList := other.findCycleFor(target); childList != nil { return append(childList, other) } } - return nil + + // This should not occur. + Failf("no deadlock found?") + panic("unreachable") } // findCycle finds a dependency cycle. -func (os *objectState) findCycle() []*objectState { - return append(os.findCycleFor(os), os) +func (ods *objectDecodeState) findCycle() []*objectDecodeState { + return append(ods.findCycleFor(ods), ods) +} + +// source implements internalCallback.source. +func (ods *objectDecodeState) source() *objectDecodeState { + return ods +} + +// callbackRun implements internalCallback.callbackRun. +func (ods *objectDecodeState) callbackRun() { + ods.blockedBy-- } // decodeState is a graph of objects in the process of being decoded. @@ -137,30 +141,66 @@ type decodeState struct { // ctx is the decode context. ctx context.Context + // r is the input stream. + r wire.Reader + + // types is the type database. + types typeDecodeDatabase + // objectByID is the set of objects in progress. - objectsByID map[uint64]*objectState + objectsByID []*objectDecodeState // deferred are objects that have been read, by no interest has been // registered yet. These will be decoded once interest in registered. - deferred map[uint64]*pb.Object + deferred map[objectID]wire.Object - // outstanding is the number of outstanding objects. - outstanding uint32 + // pending is the set of objects that are not yet complete. + pending completeList - // r is the input stream. - r io.Reader - - // stats is the passed stats object. - stats *Stats - - // recoverable is the panic recover facility. - recoverable + // stats tracks time data. + stats Stats } // lookup looks up an object in decodeState or returns nil if no such object // has been previously registered. -func (ds *decodeState) lookup(id uint64) *objectState { - return ds.objectsByID[id] +func (ds *decodeState) lookup(id objectID) *objectDecodeState { + if len(ds.objectsByID) < int(id) { + return nil + } + return ds.objectsByID[id-1] +} + +// checkComplete checks for completion. +func (ds *decodeState) checkComplete(ods *objectDecodeState) bool { + // Still blocked? + if ods.blockedBy > 0 { + return false + } + + // Track stats if relevant. + if ods.callbacks != nil && ods.typ != 0 { + ds.stats.start(ods.typ) + defer ds.stats.done() + } + + // Fire all callbacks. + for _, ic := range ods.callbacks { + ic.callbackRun() + } + + // Mark completed. + cbs := ods.callbacks + ods.callbacks = nil + ds.pending.Remove(ods) + + // Recursively check others. + for _, ic := range cbs { + if other := ic.source(); other != nil && other.blockedBy == 0 { + ds.checkComplete(other) + } + } + + return true // All set. } // wait registers a dependency on an object. @@ -168,11 +208,8 @@ func (ds *decodeState) lookup(id uint64) *objectState { // As a special case, we always allow _useable_ references back to the first // decoding object because it may have fields that are already decoded. We also // allow trivial self reference, since they can be handled internally. -func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) { +func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback func()) { switch id { - case 0: - // Nil pointer; nothing to wait for. - fallthrough case waiter.id: // Trivial self reference. fallthrough @@ -184,107 +221,188 @@ func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) { return } + // Mark as blocked. + waiter.blockedBy++ + // No nil can be returned here. - waiter.waitFor(ds.lookup(id), callback) + other := ds.lookup(id) + if callback != nil { + // Add the additional user callback. + other.addCallback(userCallback(callback)) + } + + // Mark waiter as unblocked. + other.addCallback(waiter) } // waitObject notes a blocking relationship. -func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) { - if rv, ok := p.Value.(*pb.Object_RefValue); ok { +func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, callback func()) { + if rv, ok := encoded.(*wire.Ref); ok && rv.Root != 0 { // Refs can encode pointers and maps. - ds.wait(os, rv.RefValue, callback) - } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok { + ds.wait(ods, objectID(rv.Root), callback) + } else if sv, ok := encoded.(*wire.Slice); ok && sv.Ref.Root != 0 { // See decodeObject; we need to wait for the array (if non-nil). - ds.wait(os, sv.SliceValue.RefValue, callback) - } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok { + ds.wait(ods, objectID(sv.Ref.Root), callback) + } else if iv, ok := encoded.(*wire.Interface); ok { // It's an interface (wait recurisvely). - ds.waitObject(os, iv.InterfaceValue.Value, callback) + ds.waitObject(ods, iv.Value, callback) } else if callback != nil { // Nothing to wait for: execute the callback immediately. callback() } } +// walkChild returns a child object from obj, given an accessor path. This is +// the decode-side equivalent to traverse in encode.go. +// +// For the purposes of this function, a child object is either a field within a +// struct or an array element, with one such indirection per element in +// path. The returned value may be an unexported field, so it may not be +// directly assignable. See unsafePointerTo. +func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value { + // See wire.Ref.Dots. The path here is specified in reverse order. + for i := len(path) - 1; i >= 0; i-- { + switch pc := path[i].(type) { + case *wire.FieldName: // Must be a pointer. + if obj.Kind() != reflect.Struct { + Failf("next component in child path is a field name, but the current object is not a struct. Path: %v, current obj: %#v", path, obj) + } + obj = obj.FieldByName(string(*pc)) + case wire.Index: // Embedded. + if obj.Kind() != reflect.Array { + Failf("next component in child path is an array index, but the current object is not an array. Path: %v, current obj: %#v", path, obj) + } + obj = obj.Index(int(pc)) + default: + panic("unreachable: switch should be exhaustive") + } + } + return obj +} + // register registers a decode with a type. // // This type is only used to instantiate a new object if it has not been -// registered previously. -func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState { - os, ok := ds.objectsByID[id] - if ok { - return os +// registered previously. This depends on the type provided if none is +// available in the object itself. +func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value { + // Grow the objectsByID slice. + id := objectID(r.Root) + if len(ds.objectsByID) < int(id) { + ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...) + } + + // Does this object already exist? + ods := ds.objectsByID[id-1] + if ods != nil { + return walkChild(r.Dots, ods.obj) + } + + // Create the object. + if len(r.Dots) != 0 { + typ = ds.findType(r.Type) } + v := reflect.New(typ) + ods = &objectDecodeState{ + id: id, + obj: v.Elem(), + } + ds.objectsByID[id-1] = ods + ds.pending.PushBack(ods) - // Record in the object index. - if typ.Kind() == reflect.Map { - os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()} - } else { - os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()} + // Process any deferred objects & callbacks. + if encoded, ok := ds.deferred[id]; ok { + delete(ds.deferred, id) + ds.decodeObject(ods, ods.obj, encoded) } - ds.objectsByID[id] = os - if o, ok := ds.deferred[id]; ok { - // There is a deferred object. - delete(ds.deferred, id) // Free memory. - ds.decodeObject(os, os.obj, o, "", nil) - } else { - // There is no deferred object. - ds.outstanding++ + return walkChild(r.Dots, ods.obj) +} + +// objectDecoder is for decoding structs. +type objectDecoder struct { + // ds is decodeState. + ds *decodeState + + // ods is current object being decoded. + ods *objectDecodeState + + // reconciledTypeEntry is the reconciled type information. + rte *reconciledTypeEntry + + // encoded is the encoded object state. + encoded *wire.Struct +} + +// load is helper for the public methods on Source. +func (od *objectDecoder) load(slot int, objPtr reflect.Value, wait bool, fn func()) { + // Note that we have reconciled the type and may remap the fields here + // to match what's expected by the decoder. The "slot" parameter here + // is in terms of the local type, where the fields in the encoded + // object are in terms of the wire object's type, which might be in a + // different order (but will have the same fields). + v := *od.encoded.Field(od.rte.FieldOrder[slot]) + od.ds.decodeObject(od.ods, objPtr.Elem(), v) + if wait { + // Mark this individual object a blocker. + od.ds.waitObject(od.ods, v, fn) } +} - return os +// aterLoad implements Source.AfterLoad. +func (od *objectDecoder) afterLoad(fn func()) { + // Queue the local callback; this will execute when all of the above + // data dependencies have been cleared. + od.ods.addCallback(userCallback(fn)) } // decodeStruct decodes a struct value. -func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) { - // Set the fields. - m := Map{newInternalMap(nil, ds, os)} - defer internalMapPool.Put(m.internalMap) - for _, field := range s.Fields { - m.data = append(m.data, entry{ - name: field.Name, - object: field.Value, - }) - } - - // Sort the fields for efficient searching. - // - // Technically, these should already appear in sorted order in the - // state ordering, so this cost is effectively a single scan to ensure - // that the order is correct. - if len(m.data) > 1 { - sort.Slice(m.data, func(i, j int) bool { - return m.data[i].name < m.data[j].name - }) - } - - // Invoke the load; this will recursively decode other objects. - fns, ok := registeredTypes.lookupFns(obj.Addr().Type()) - if ok { - // Invoke the loader. - fns.invokeLoad(obj.Addr(), m) - } else if obj.NumField() == 0 { - // Allow anonymous empty structs. - return - } else { +func (ds *decodeState) decodeStruct(ods *objectDecodeState, obj reflect.Value, encoded *wire.Struct) { + if encoded.TypeID == 0 { + // Allow anonymous empty structs, but only if the encoded + // object also has no fields. + if encoded.Fields() == 0 && obj.NumField() == 0 { + return + } + // Propagate an error. - panic(fmt.Errorf("unregistered type %s", obj.Type())) + Failf("empty struct on wire %#v has field mismatch with type %q", encoded, obj.Type().Name()) + } + + // Lookup the object type. + rte := ds.types.Lookup(typeID(encoded.TypeID), obj.Type()) + ods.typ = typeID(encoded.TypeID) + + // Invoke the loader. + od := objectDecoder{ + ds: ds, + ods: ods, + rte: rte, + encoded: encoded, + } + ds.stats.start(ods.typ) + defer ds.stats.done() + if sl, ok := obj.Addr().Interface().(SaverLoader); ok { + // Note: may be a registered empty struct which does not + // implement the saver/loader interfaces. + sl.StateLoad(Source{internal: od}) } } // decodeMap decodes a map value. -func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) { +func (ds *decodeState) decodeMap(ods *objectDecodeState, obj reflect.Value, encoded *wire.Map) { if obj.IsNil() { + // See pointerTo. obj.Set(reflect.MakeMap(obj.Type())) } - for i := 0; i < len(m.Keys); i++ { + for i := 0; i < len(encoded.Keys); i++ { // Decode the objects. kv := reflect.New(obj.Type().Key()).Elem() vv := reflect.New(obj.Type().Elem()).Elem() - ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i) - ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface()) - ds.waitObject(os, m.Keys[i], nil) - ds.waitObject(os, m.Values[i], nil) + ds.decodeObject(ods, kv, encoded.Keys[i]) + ds.decodeObject(ods, vv, encoded.Values[i]) + ds.waitObject(ods, encoded.Keys[i], nil) + ds.waitObject(ods, encoded.Values[i], nil) // Set in the map. obj.SetMapIndex(kv, vv) @@ -292,271 +410,294 @@ func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) } // decodeArray decodes an array value. -func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) { - if len(a.Contents) != obj.Len() { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents))) +func (ds *decodeState) decodeArray(ods *objectDecodeState, obj reflect.Value, encoded *wire.Array) { + if len(encoded.Contents) != obj.Len() { + Failf("mismatching array length expect=%d, actual=%d", obj.Len(), len(encoded.Contents)) } // Decode the contents into the array. - for i := 0; i < len(a.Contents); i++ { - ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i) - ds.waitObject(os, a.Contents[i], nil) + for i := 0; i < len(encoded.Contents); i++ { + ds.decodeObject(ods, obj.Index(i), encoded.Contents[i]) + ds.waitObject(ods, encoded.Contents[i], nil) } } -// decodeInterface decodes an interface value. -func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) { - // Is this a nil value? - if i.Type == "" { - return // Just leave obj alone. +// findType finds the type for the given wire.TypeSpecs. +func (ds *decodeState) findType(t wire.TypeSpec) reflect.Type { + switch x := t.(type) { + case wire.TypeID: + typ := ds.types.LookupType(typeID(x)) + rte := ds.types.Lookup(typeID(x), typ) + return rte.LocalType + case *wire.TypeSpecPointer: + return reflect.PtrTo(ds.findType(x.Type)) + case *wire.TypeSpecArray: + return reflect.ArrayOf(int(x.Count), ds.findType(x.Type)) + case *wire.TypeSpecSlice: + return reflect.SliceOf(ds.findType(x.Type)) + case *wire.TypeSpecMap: + return reflect.MapOf(ds.findType(x.Key), ds.findType(x.Value)) + default: + // Should not happen. + Failf("unknown type %#v", t) } + panic("unreachable") +} - // Get the dispatchable type. This may not be used if the given - // reference has already been resolved, but if not we need to know the - // type to create. - t, ok := registeredTypes.lookupType(i.Type) - if !ok { - panic(fmt.Errorf("no valid type for %q", i.Type)) +// decodeInterface decodes an interface value. +func (ds *decodeState) decodeInterface(ods *objectDecodeState, obj reflect.Value, encoded *wire.Interface) { + if _, ok := encoded.Type.(wire.TypeSpecNil); ok { + // Special case; the nil object. Just decode directly, which + // will read nil from the wire (if encoded correctly). + ds.decodeObject(ods, obj, encoded.Value) + return } - if obj.Kind() != reflect.Map { - // Set the obj to be the given typed value; this actually sets - // obj to be a non-zero value -- namely, it inserts type - // information. There's no need to do this for maps. - obj.Set(reflect.Zero(t)) + // We now need to resolve the actual type. + typ := ds.findType(encoded.Type) + + // We need to imbue type information here, then we can proceed to + // decode normally. In order to avoid issues with setting value-types, + // we create a new non-interface version of this object. We will then + // set the interface object to be equal to whatever we decode. + origObj := obj + obj = reflect.New(typ).Elem() + defer origObj.Set(obj) + + // With the object now having sufficient type information to actually + // have Set called on it, we can proceed to decode the value. + ds.decodeObject(ods, obj, encoded.Value) +} + +// isFloatEq determines if x and y represent the same value. +func isFloatEq(x float64, y float64) bool { + switch { + case math.IsNaN(x): + return math.IsNaN(y) + case math.IsInf(x, 1): + return math.IsInf(y, 1) + case math.IsInf(x, -1): + return math.IsInf(y, -1) + default: + return x == y } +} - // Decode the dereferenced element; there is no need to wait here, as - // the interface object shares the current object state. - ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type) +// isComplexEq determines if x and y represent the same value. +func isComplexEq(x complex128, y complex128) bool { + return isFloatEq(real(x), real(y)) && isFloatEq(imag(x), imag(y)) } // decodeObject decodes a object value. -func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) { - ds.push(false, format, param) - ds.stats.Add(obj) - ds.stats.Start(obj) - - switch x := object.GetValue().(type) { - case *pb.Object_BoolValue: - obj.SetBool(x.BoolValue) - case *pb.Object_StringValue: - obj.SetString(string(x.StringValue)) - case *pb.Object_Int64Value: - obj.SetInt(x.Int64Value) - if obj.Int() != x.Int64Value { - panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type())) +func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, encoded wire.Object) { + switch x := encoded.(type) { + case wire.Nil: // Fast path: first. + // We leave obj alone here. That's because if obj represents an + // interface, it may have been imbued with type information in + // decodeInterface, and we don't want to destroy that. + case *wire.Ref: + // Nil pointers may be encoded in a "forceValue" context. For + // those we just leave it alone as the value will already be + // correct (nil). + if id := objectID(x.Root); id == 0 { + return } - case *pb.Object_Uint64Value: - obj.SetUint(x.Uint64Value) - if obj.Uint() != x.Uint64Value { - panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type())) - } - case *pb.Object_DoubleValue: - obj.SetFloat(x.DoubleValue) - if obj.Float() != x.DoubleValue { - panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type())) - } - case *pb.Object_RefValue: - // Resolve the pointer itself, even though the object may not - // be decoded yet. You need to use wait() in order to ensure - // that is the case. See wait above, and Map.Barrier. - if id := x.RefValue; id != 0 { - // Decoding the interface should have imparted type - // information, so from this point it's safe to resolve - // and use this dynamic information for actually - // creating the object in register. - // - // (For non-interfaces this is a no-op). - dyntyp := reflect.TypeOf(obj.Interface()) - if dyntyp.Kind() == reflect.Map { - // Remove the map object count here to avoid - // double counting, as this object will be - // counted again when it gets processed later. - // We do not add a reference count as the - // reference is artificial. - ds.stats.Remove(obj) - obj.Set(ds.register(id, dyntyp).obj) - } else if dyntyp.Kind() == reflect.Ptr { - ds.push(true /* dereference */, "", nil) - obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) - ds.pop() - } else { - obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) + + // Note that if this is a map type, we go through a level of + // indirection to allow for map aliasing. + if obj.Kind() == reflect.Map { + v := ds.register(x, obj.Type()) + if v.IsNil() { + // Note that we don't want to clobber the map + // if has already been decoded by decodeMap. We + // just make it so that we have a consistent + // reference when that eventually does happen. + v.Set(reflect.MakeMap(v.Type())) } - } else { - // We leave obj alone here. That's because if obj - // represents an interface, it may have been embued - // with type information in decodeInterface, and we - // don't want to destroy that information. + obj.Set(v) + return } - case *pb.Object_SliceValue: - // It's okay to slice the array here, since the contents will - // still be provided later on. These semantics are a bit - // strange but they are handled in the Map.Barrier properly. - // - // The special semantics of zero ref apply here too. - if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 { - v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem()) - obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity))) + + // Normal assignment: authoritative only if no dots. + v := ds.register(x, obj.Type().Elem()) + if v.IsValid() { + obj.Set(unsafePointerTo(v)) } - case *pb.Object_ArrayValue: - ds.decodeArray(os, obj, x.ArrayValue) - case *pb.Object_StructValue: - ds.decodeStruct(os, obj, x.StructValue) - case *pb.Object_MapValue: - ds.decodeMap(os, obj, x.MapValue) - case *pb.Object_InterfaceValue: - ds.decodeInterface(os, obj, x.InterfaceValue) - case *pb.Object_ByteArrayValue: - copyArray(obj, reflect.ValueOf(x.ByteArrayValue)) - case *pb.Object_Uint16ArrayValue: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details. - s := x.Uint16ArrayValue.Values - t := obj.Slice(0, obj.Len()).Interface().([]uint16) - if len(t) != len(s) { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) + case wire.Bool: + obj.SetBool(bool(x)) + case wire.Int: + obj.SetInt(int64(x)) + if obj.Int() != int64(x) { + Failf("signed integer truncated from %v to %v", int64(x), obj.Int()) } - for i := range s { - t[i] = uint16(s[i]) + case wire.Uint: + obj.SetUint(uint64(x)) + if obj.Uint() != uint64(x) { + Failf("unsigned integer truncated from %v to %v", uint64(x), obj.Uint()) } - case *pb.Object_Uint32ArrayValue: - copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values)) - case *pb.Object_Uint64ArrayValue: - copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values)) - case *pb.Object_UintptrArrayValue: - copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0)))) - case *pb.Object_Int8ArrayValue: - copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0)))) - case *pb.Object_Int16ArrayValue: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details. - s := x.Int16ArrayValue.Values - t := obj.Slice(0, obj.Len()).Interface().([]int16) - if len(t) != len(s) { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) + case wire.Float32: + obj.SetFloat(float64(x)) + case wire.Float64: + obj.SetFloat(float64(x)) + if !isFloatEq(obj.Float(), float64(x)) { + Failf("floating point number truncated from %v to %v", float64(x), obj.Float()) } - for i := range s { - t[i] = int16(s[i]) + case *wire.Complex64: + obj.SetComplex(complex128(*x)) + case *wire.Complex128: + obj.SetComplex(complex128(*x)) + if !isComplexEq(obj.Complex(), complex128(*x)) { + Failf("complex number truncated from %v to %v", complex128(*x), obj.Complex()) } - case *pb.Object_Int32ArrayValue: - copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values)) - case *pb.Object_Int64ArrayValue: - copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values)) - case *pb.Object_BoolArrayValue: - copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values)) - case *pb.Object_Float64ArrayValue: - copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values)) - case *pb.Object_Float32ArrayValue: - copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values)) + case *wire.String: + obj.SetString(string(*x)) + case *wire.Slice: + // See *wire.Ref above; same applies. + if id := objectID(x.Ref.Root); id == 0 { + return + } + // Note that it's fine to slice the array here and assume that + // contents will still be filled in later on. + typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type. + v := ds.register(&x.Ref, typ) + obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity))) + case *wire.Array: + ds.decodeArray(ods, obj, x) + case *wire.Struct: + ds.decodeStruct(ods, obj, x) + case *wire.Map: + ds.decodeMap(ods, obj, x) + case *wire.Interface: + ds.decodeInterface(ods, obj, x) default: // Shoud not happen, not propagated as an error. - panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type())) - } - - ds.stats.Done() - ds.pop() -} - -func copyArray(dest reflect.Value, src reflect.Value) { - if dest.Len() != src.Len() { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len())) + Failf("unknown object %#v for %q", encoded, obj.Type().Name()) } - reflect.Copy(dest, castSlice(src, dest.Type().Elem())) } -// Deserialize deserializes the object state. +// Load deserializes the object graph rooted at obj. // // This function may panic and should be run in safely(). -func (ds *decodeState) Deserialize(obj reflect.Value) { - ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()} - ds.outstanding = 1 // The root object. +func (ds *decodeState) Load(obj reflect.Value) { + ds.stats.init() + defer ds.stats.fini(func(id typeID) string { + return ds.types.LookupName(id) + }) + + // Create the root object. + ds.objectsByID = append(ds.objectsByID, &objectDecodeState{ + id: 1, + obj: obj, + }) + + // Read the number of objects. + lastID, object, err := ReadHeader(ds.r) + if err != nil { + Failf("header error: %w", err) + } + if !object { + Failf("object missing") + } + + // Decode all objects. + var ( + encoded wire.Object + ods *objectDecodeState + id = objectID(1) + tid = typeID(1) + ) + if err := safely(func() { + // Decode all objects in the stream. + // + // Note that the structure of this decoding loop should match + // the raw decoding loop in printer.go. + for id <= objectID(lastID) { + // Unmarshal the object. + encoded = wire.Load(ds.r) + + // Is this a type object? Handle inline. + if wt, ok := encoded.(*wire.Type); ok { + ds.types.Register(wt) + tid++ + encoded = nil + continue + } - // Decode all objects in the stream. - // - // See above, we never process objects while we have no outstanding - // interests (other than the very first object). - for id := uint64(1); ds.outstanding > 0; id++ { - os := ds.lookup(id) - ds.stats.Start(os.obj) - - o, err := ds.readObject() - if err != nil { - panic(err) - } + // Actually resolve the object. + ods = ds.lookup(id) + if ods != nil { + // Decode the object. + ds.decodeObject(ods, ods.obj, encoded) + } else { + // If an object hasn't had interest registered + // previously or isn't yet valid, we deferred + // decoding until interest is registered. + ds.deferred[id] = encoded + } - if os != nil { - // Decode the object. - ds.from = &os.path - ds.decodeObject(os, os.obj, o, "", nil) - ds.outstanding-- + // For error handling. + ods = nil + encoded = nil + id++ + } + }); err != nil { + // Include as much information as we can, taking into account + // the possible state transitions above. + if ods != nil { + Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err) + } else if encoded != nil { + Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err) } else { - // If an object hasn't had interest registered - // previously, we deferred decoding until interest is - // registered. - ds.deferred[id] = o + Failf("general decoding error: %w", err) } - - ds.stats.Done() - } - - // Check the zero-length header at the end. - length, object, err := ReadHeader(ds.r) - if err != nil { - panic(err) - } - if length != 0 { - panic(fmt.Sprintf("expected zero-length terminal, got %d", length)) - } - if object { - panic("expected non-object terminal") } // Check if we have any deferred objects. - if count := len(ds.deferred); count > 0 { - // Shoud not happen, not propagated as an error. - panic(fmt.Sprintf("still have %d deferred objects", count)) - } - - // Scan and fire all callbacks. - for _, os := range ds.objectsByID { - os.checkComplete(ds.stats) + for id, encoded := range ds.deferred { + // Shoud never happen, the graph was bogus. + Failf("still have deferred objects: one is ID %d, %#v", id, encoded) } - // Check if we have any remaining dependency cycles. - for _, os := range ds.objectsByID { - if !os.complete() { - // This must be the result of a dependency cycle. - cycle := os.findCycle() - var buf bytes.Buffer - buf.WriteString("dependency cycle: {") - for i, cycleOS := range cycle { - if i > 0 { - buf.WriteString(" => ") + // Scan and fire all callbacks. We iterate over the list of incomplete + // objects until all have been finished. We stop iterating if no + // objects become complete (there is a dependency cycle). + // + // Note that we iterate backwards here, because there will be a strong + // tendendcy for blocking relationships to go from earlier objects to + // later (deeper) objects in the graph. This will reduce the number of + // iterations required to finish all objects. + if err := safely(func() { + for ds.pending.Back() != nil { + thisCycle := false + for ods = ds.pending.Back(); ods != nil; { + if ds.checkComplete(ods) { + thisCycle = true + break } - buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type())) + ods = ods.Prev() + } + if !thisCycle { + break } - buf.WriteString("}") - // Panic as an error; propagate to the caller. - panic(errors.New(string(buf.Bytes()))) } - } -} - -type byteReader struct { - io.Reader -} - -// ReadByte implements io.ByteReader. -func (br byteReader) ReadByte() (byte, error) { - var b [1]byte - n, err := br.Reader.Read(b[:]) - if n > 0 { - return b[0], nil - } else if err != nil { - return 0, err - } else { - return 0, io.ErrUnexpectedEOF + }); err != nil { + Failf("error executing callbacks for %#v: %w", ods.obj.Interface(), err) + } + + // Check if we have any remaining dependency cycles. If there are any + // objects left in the pending list, then it must be due to a cycle. + if ods := ds.pending.Front(); ods != nil { + // This must be the result of a dependency cycle. + cycle := ods.findCycle() + var buf bytes.Buffer + buf.WriteString("dependency cycle: {") + for i, cycleOS := range cycle { + if i > 0 { + buf.WriteString(" => ") + } + fmt.Fprintf(&buf, "%q", cycleOS.obj.Type()) + } + buf.WriteString("}") + Failf("incomplete graph: %s", string(buf.Bytes())) } } @@ -565,45 +706,20 @@ func (br byteReader) ReadByte() (byte, error) { // Each object written to the statefile is prefixed with a header. See // WriteHeader for more information; these functions are exported to allow // non-state writes to the file to play nice with debugging tools. -func ReadHeader(r io.Reader) (length uint64, object bool, err error) { +func ReadHeader(r wire.Reader) (length uint64, object bool, err error) { // Read the header. - length, err = binary.ReadUvarint(byteReader{r}) + err = safely(func() { + length = wire.LoadUint(r) + }) if err != nil { - return + // On the header, pass raw I/O errors. + if sErr, ok := err.(*ErrState); ok { + return 0, false, sErr.Unwrap() + } } // Decode whether the object is valid. - object = length&0x1 != 0 - length = length >> 1 + object = length&objectFlag != 0 + length &^= objectFlag return } - -// readObject reads an object from the stream. -func (ds *decodeState) readObject() (*pb.Object, error) { - // Read the header. - length, object, err := ReadHeader(ds.r) - if err != nil { - return nil, err - } - if !object { - return nil, fmt.Errorf("invalid object header") - } - - // Read the object. - buf := make([]byte, length) - for done := 0; done < len(buf); { - n, err := ds.r.Read(buf[done:]) - done += n - if n == 0 && err != nil { - return nil, err - } - } - - // Unmarshal. - obj := new(pb.Object) - if err := proto.Unmarshal(buf, obj); err != nil { - return nil, err - } - - return obj, nil -} |