summaryrefslogtreecommitdiffhomepage
path: root/pkg/state/decode.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/state/decode.go')
-rw-r--r--pkg/state/decode.go918
1 files changed, 517 insertions, 401 deletions
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
index 590c241a3..c9971cdf6 100644
--- a/pkg/state/decode.go
+++ b/pkg/state/decode.go
@@ -17,28 +17,49 @@ package state
import (
"bytes"
"context"
- "encoding/binary"
- "errors"
"fmt"
- "io"
+ "math"
"reflect"
- "sort"
- "github.com/golang/protobuf/proto"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
-// objectState represents an object that may be in the process of being
+// internalCallback is a interface called on object completion.
+//
+// There are two implementations: objectDecodeState & userCallback.
+type internalCallback interface {
+ // source returns the dependent object. May be nil.
+ source() *objectDecodeState
+
+ // callbackRun executes the callback.
+ callbackRun()
+}
+
+// userCallback is an implementation of internalCallback.
+type userCallback func()
+
+// source implements internalCallback.source.
+func (userCallback) source() *objectDecodeState {
+ return nil
+}
+
+// callbackRun implements internalCallback.callbackRun.
+func (uc userCallback) callbackRun() {
+ uc()
+}
+
+// objectDecodeState represents an object that may be in the process of being
// decoded. Specifically, it represents either a decoded object, or an an
// interest in a future object that will be decoded. When that interest is
// registered (via register), the storage for the object will be created, but
// it will not be decoded until the object is encountered in the stream.
-type objectState struct {
+type objectDecodeState struct {
// id is the id for this object.
- //
- // If this field is zero, then this is an anonymous (unregistered,
- // non-reference primitive) object. This is immutable.
- id uint64
+ id objectID
+
+ // typ is the id for this typeID. This may be zero if this is not a
+ // type-registered structure.
+ typ typeID
// obj is the object. This may or may not be valid yet, depending on
// whether complete returns true. However, regardless of whether the
@@ -57,69 +78,52 @@ type objectState struct {
// blockedBy is the number of dependencies this object has.
blockedBy int
- // blocking is a list of the objects blocked by this one.
- blocking []*objectState
+ // callbacksInline is inline storage for callbacks.
+ callbacksInline [2]internalCallback
// callbacks is a set of callbacks to execute on load.
- callbacks []func()
-
- // path is the decoding path to the object.
- path recoverable
-}
-
-// complete indicates the object is complete.
-func (os *objectState) complete() bool {
- return os.blockedBy == 0 && len(os.callbacks) == 0
-}
-
-// checkComplete checks for completion. If the object is complete, pending
-// callbacks will be executed and checkComplete will be called on downstream
-// objects (those depending on this one).
-func (os *objectState) checkComplete(stats *Stats) {
- if os.blockedBy > 0 {
- return
- }
- stats.Start(os.obj)
+ callbacks []internalCallback
- // Fire all callbacks.
- for _, fn := range os.callbacks {
- fn()
- }
- os.callbacks = nil
-
- // Clear all blocked objects.
- for _, other := range os.blocking {
- other.blockedBy--
- other.checkComplete(stats)
- }
- os.blocking = nil
- stats.Done()
+ completeEntry
}
-// waitFor queues a dependency on the given object.
-func (os *objectState) waitFor(other *objectState, callback func()) {
- os.blockedBy++
- other.blocking = append(other.blocking, os)
- if callback != nil {
- other.callbacks = append(other.callbacks, callback)
+// addCallback adds a callback to the objectDecodeState.
+func (ods *objectDecodeState) addCallback(ic internalCallback) {
+ if ods.callbacks == nil {
+ ods.callbacks = ods.callbacksInline[:0]
}
+ ods.callbacks = append(ods.callbacks, ic)
}
// findCycleFor returns when the given object is found in the blocking set.
-func (os *objectState) findCycleFor(target *objectState) []*objectState {
- for _, other := range os.blocking {
- if other == target {
- return []*objectState{target}
+func (ods *objectDecodeState) findCycleFor(target *objectDecodeState) []*objectDecodeState {
+ for _, ic := range ods.callbacks {
+ other := ic.source()
+ if other != nil && other == target {
+ return []*objectDecodeState{target}
} else if childList := other.findCycleFor(target); childList != nil {
return append(childList, other)
}
}
- return nil
+
+ // This should not occur.
+ Failf("no deadlock found?")
+ panic("unreachable")
}
// findCycle finds a dependency cycle.
-func (os *objectState) findCycle() []*objectState {
- return append(os.findCycleFor(os), os)
+func (ods *objectDecodeState) findCycle() []*objectDecodeState {
+ return append(ods.findCycleFor(ods), ods)
+}
+
+// source implements internalCallback.source.
+func (ods *objectDecodeState) source() *objectDecodeState {
+ return ods
+}
+
+// callbackRun implements internalCallback.callbackRun.
+func (ods *objectDecodeState) callbackRun() {
+ ods.blockedBy--
}
// decodeState is a graph of objects in the process of being decoded.
@@ -137,30 +141,66 @@ type decodeState struct {
// ctx is the decode context.
ctx context.Context
+ // r is the input stream.
+ r wire.Reader
+
+ // types is the type database.
+ types typeDecodeDatabase
+
// objectByID is the set of objects in progress.
- objectsByID map[uint64]*objectState
+ objectsByID []*objectDecodeState
// deferred are objects that have been read, by no interest has been
// registered yet. These will be decoded once interest in registered.
- deferred map[uint64]*pb.Object
+ deferred map[objectID]wire.Object
- // outstanding is the number of outstanding objects.
- outstanding uint32
+ // pending is the set of objects that are not yet complete.
+ pending completeList
- // r is the input stream.
- r io.Reader
-
- // stats is the passed stats object.
- stats *Stats
-
- // recoverable is the panic recover facility.
- recoverable
+ // stats tracks time data.
+ stats Stats
}
// lookup looks up an object in decodeState or returns nil if no such object
// has been previously registered.
-func (ds *decodeState) lookup(id uint64) *objectState {
- return ds.objectsByID[id]
+func (ds *decodeState) lookup(id objectID) *objectDecodeState {
+ if len(ds.objectsByID) < int(id) {
+ return nil
+ }
+ return ds.objectsByID[id-1]
+}
+
+// checkComplete checks for completion.
+func (ds *decodeState) checkComplete(ods *objectDecodeState) bool {
+ // Still blocked?
+ if ods.blockedBy > 0 {
+ return false
+ }
+
+ // Track stats if relevant.
+ if ods.callbacks != nil && ods.typ != 0 {
+ ds.stats.start(ods.typ)
+ defer ds.stats.done()
+ }
+
+ // Fire all callbacks.
+ for _, ic := range ods.callbacks {
+ ic.callbackRun()
+ }
+
+ // Mark completed.
+ cbs := ods.callbacks
+ ods.callbacks = nil
+ ds.pending.Remove(ods)
+
+ // Recursively check others.
+ for _, ic := range cbs {
+ if other := ic.source(); other != nil && other.blockedBy == 0 {
+ ds.checkComplete(other)
+ }
+ }
+
+ return true // All set.
}
// wait registers a dependency on an object.
@@ -168,11 +208,8 @@ func (ds *decodeState) lookup(id uint64) *objectState {
// As a special case, we always allow _useable_ references back to the first
// decoding object because it may have fields that are already decoded. We also
// allow trivial self reference, since they can be handled internally.
-func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
+func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback func()) {
switch id {
- case 0:
- // Nil pointer; nothing to wait for.
- fallthrough
case waiter.id:
// Trivial self reference.
fallthrough
@@ -184,107 +221,188 @@ func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
return
}
+ // Mark as blocked.
+ waiter.blockedBy++
+
// No nil can be returned here.
- waiter.waitFor(ds.lookup(id), callback)
+ other := ds.lookup(id)
+ if callback != nil {
+ // Add the additional user callback.
+ other.addCallback(userCallback(callback))
+ }
+
+ // Mark waiter as unblocked.
+ other.addCallback(waiter)
}
// waitObject notes a blocking relationship.
-func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) {
- if rv, ok := p.Value.(*pb.Object_RefValue); ok {
+func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, callback func()) {
+ if rv, ok := encoded.(*wire.Ref); ok && rv.Root != 0 {
// Refs can encode pointers and maps.
- ds.wait(os, rv.RefValue, callback)
- } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok {
+ ds.wait(ods, objectID(rv.Root), callback)
+ } else if sv, ok := encoded.(*wire.Slice); ok && sv.Ref.Root != 0 {
// See decodeObject; we need to wait for the array (if non-nil).
- ds.wait(os, sv.SliceValue.RefValue, callback)
- } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok {
+ ds.wait(ods, objectID(sv.Ref.Root), callback)
+ } else if iv, ok := encoded.(*wire.Interface); ok {
// It's an interface (wait recurisvely).
- ds.waitObject(os, iv.InterfaceValue.Value, callback)
+ ds.waitObject(ods, iv.Value, callback)
} else if callback != nil {
// Nothing to wait for: execute the callback immediately.
callback()
}
}
+// walkChild returns a child object from obj, given an accessor path. This is
+// the decode-side equivalent to traverse in encode.go.
+//
+// For the purposes of this function, a child object is either a field within a
+// struct or an array element, with one such indirection per element in
+// path. The returned value may be an unexported field, so it may not be
+// directly assignable. See unsafePointerTo.
+func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
+ // See wire.Ref.Dots. The path here is specified in reverse order.
+ for i := len(path) - 1; i >= 0; i-- {
+ switch pc := path[i].(type) {
+ case *wire.FieldName: // Must be a pointer.
+ if obj.Kind() != reflect.Struct {
+ Failf("next component in child path is a field name, but the current object is not a struct. Path: %v, current obj: %#v", path, obj)
+ }
+ obj = obj.FieldByName(string(*pc))
+ case wire.Index: // Embedded.
+ if obj.Kind() != reflect.Array {
+ Failf("next component in child path is an array index, but the current object is not an array. Path: %v, current obj: %#v", path, obj)
+ }
+ obj = obj.Index(int(pc))
+ default:
+ panic("unreachable: switch should be exhaustive")
+ }
+ }
+ return obj
+}
+
// register registers a decode with a type.
//
// This type is only used to instantiate a new object if it has not been
-// registered previously.
-func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState {
- os, ok := ds.objectsByID[id]
- if ok {
- return os
+// registered previously. This depends on the type provided if none is
+// available in the object itself.
+func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value {
+ // Grow the objectsByID slice.
+ id := objectID(r.Root)
+ if len(ds.objectsByID) < int(id) {
+ ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...)
+ }
+
+ // Does this object already exist?
+ ods := ds.objectsByID[id-1]
+ if ods != nil {
+ return walkChild(r.Dots, ods.obj)
+ }
+
+ // Create the object.
+ if len(r.Dots) != 0 {
+ typ = ds.findType(r.Type)
}
+ v := reflect.New(typ)
+ ods = &objectDecodeState{
+ id: id,
+ obj: v.Elem(),
+ }
+ ds.objectsByID[id-1] = ods
+ ds.pending.PushBack(ods)
- // Record in the object index.
- if typ.Kind() == reflect.Map {
- os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()}
- } else {
- os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()}
+ // Process any deferred objects & callbacks.
+ if encoded, ok := ds.deferred[id]; ok {
+ delete(ds.deferred, id)
+ ds.decodeObject(ods, ods.obj, encoded)
}
- ds.objectsByID[id] = os
- if o, ok := ds.deferred[id]; ok {
- // There is a deferred object.
- delete(ds.deferred, id) // Free memory.
- ds.decodeObject(os, os.obj, o, "", nil)
- } else {
- // There is no deferred object.
- ds.outstanding++
+ return walkChild(r.Dots, ods.obj)
+}
+
+// objectDecoder is for decoding structs.
+type objectDecoder struct {
+ // ds is decodeState.
+ ds *decodeState
+
+ // ods is current object being decoded.
+ ods *objectDecodeState
+
+ // reconciledTypeEntry is the reconciled type information.
+ rte *reconciledTypeEntry
+
+ // encoded is the encoded object state.
+ encoded *wire.Struct
+}
+
+// load is helper for the public methods on Source.
+func (od *objectDecoder) load(slot int, objPtr reflect.Value, wait bool, fn func()) {
+ // Note that we have reconciled the type and may remap the fields here
+ // to match what's expected by the decoder. The "slot" parameter here
+ // is in terms of the local type, where the fields in the encoded
+ // object are in terms of the wire object's type, which might be in a
+ // different order (but will have the same fields).
+ v := *od.encoded.Field(od.rte.FieldOrder[slot])
+ od.ds.decodeObject(od.ods, objPtr.Elem(), v)
+ if wait {
+ // Mark this individual object a blocker.
+ od.ds.waitObject(od.ods, v, fn)
}
+}
- return os
+// aterLoad implements Source.AfterLoad.
+func (od *objectDecoder) afterLoad(fn func()) {
+ // Queue the local callback; this will execute when all of the above
+ // data dependencies have been cleared.
+ od.ods.addCallback(userCallback(fn))
}
// decodeStruct decodes a struct value.
-func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) {
- // Set the fields.
- m := Map{newInternalMap(nil, ds, os)}
- defer internalMapPool.Put(m.internalMap)
- for _, field := range s.Fields {
- m.data = append(m.data, entry{
- name: field.Name,
- object: field.Value,
- })
- }
-
- // Sort the fields for efficient searching.
- //
- // Technically, these should already appear in sorted order in the
- // state ordering, so this cost is effectively a single scan to ensure
- // that the order is correct.
- if len(m.data) > 1 {
- sort.Slice(m.data, func(i, j int) bool {
- return m.data[i].name < m.data[j].name
- })
- }
-
- // Invoke the load; this will recursively decode other objects.
- fns, ok := registeredTypes.lookupFns(obj.Addr().Type())
- if ok {
- // Invoke the loader.
- fns.invokeLoad(obj.Addr(), m)
- } else if obj.NumField() == 0 {
- // Allow anonymous empty structs.
- return
- } else {
+func (ds *decodeState) decodeStruct(ods *objectDecodeState, obj reflect.Value, encoded *wire.Struct) {
+ if encoded.TypeID == 0 {
+ // Allow anonymous empty structs, but only if the encoded
+ // object also has no fields.
+ if encoded.Fields() == 0 && obj.NumField() == 0 {
+ return
+ }
+
// Propagate an error.
- panic(fmt.Errorf("unregistered type %s", obj.Type()))
+ Failf("empty struct on wire %#v has field mismatch with type %q", encoded, obj.Type().Name())
+ }
+
+ // Lookup the object type.
+ rte := ds.types.Lookup(typeID(encoded.TypeID), obj.Type())
+ ods.typ = typeID(encoded.TypeID)
+
+ // Invoke the loader.
+ od := objectDecoder{
+ ds: ds,
+ ods: ods,
+ rte: rte,
+ encoded: encoded,
+ }
+ ds.stats.start(ods.typ)
+ defer ds.stats.done()
+ if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
+ // Note: may be a registered empty struct which does not
+ // implement the saver/loader interfaces.
+ sl.StateLoad(Source{internal: od})
}
}
// decodeMap decodes a map value.
-func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) {
+func (ds *decodeState) decodeMap(ods *objectDecodeState, obj reflect.Value, encoded *wire.Map) {
if obj.IsNil() {
+ // See pointerTo.
obj.Set(reflect.MakeMap(obj.Type()))
}
- for i := 0; i < len(m.Keys); i++ {
+ for i := 0; i < len(encoded.Keys); i++ {
// Decode the objects.
kv := reflect.New(obj.Type().Key()).Elem()
vv := reflect.New(obj.Type().Elem()).Elem()
- ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i)
- ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface())
- ds.waitObject(os, m.Keys[i], nil)
- ds.waitObject(os, m.Values[i], nil)
+ ds.decodeObject(ods, kv, encoded.Keys[i])
+ ds.decodeObject(ods, vv, encoded.Values[i])
+ ds.waitObject(ods, encoded.Keys[i], nil)
+ ds.waitObject(ods, encoded.Values[i], nil)
// Set in the map.
obj.SetMapIndex(kv, vv)
@@ -292,271 +410,294 @@ func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map)
}
// decodeArray decodes an array value.
-func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) {
- if len(a.Contents) != obj.Len() {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents)))
+func (ds *decodeState) decodeArray(ods *objectDecodeState, obj reflect.Value, encoded *wire.Array) {
+ if len(encoded.Contents) != obj.Len() {
+ Failf("mismatching array length expect=%d, actual=%d", obj.Len(), len(encoded.Contents))
}
// Decode the contents into the array.
- for i := 0; i < len(a.Contents); i++ {
- ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i)
- ds.waitObject(os, a.Contents[i], nil)
+ for i := 0; i < len(encoded.Contents); i++ {
+ ds.decodeObject(ods, obj.Index(i), encoded.Contents[i])
+ ds.waitObject(ods, encoded.Contents[i], nil)
}
}
-// decodeInterface decodes an interface value.
-func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) {
- // Is this a nil value?
- if i.Type == "" {
- return // Just leave obj alone.
+// findType finds the type for the given wire.TypeSpecs.
+func (ds *decodeState) findType(t wire.TypeSpec) reflect.Type {
+ switch x := t.(type) {
+ case wire.TypeID:
+ typ := ds.types.LookupType(typeID(x))
+ rte := ds.types.Lookup(typeID(x), typ)
+ return rte.LocalType
+ case *wire.TypeSpecPointer:
+ return reflect.PtrTo(ds.findType(x.Type))
+ case *wire.TypeSpecArray:
+ return reflect.ArrayOf(int(x.Count), ds.findType(x.Type))
+ case *wire.TypeSpecSlice:
+ return reflect.SliceOf(ds.findType(x.Type))
+ case *wire.TypeSpecMap:
+ return reflect.MapOf(ds.findType(x.Key), ds.findType(x.Value))
+ default:
+ // Should not happen.
+ Failf("unknown type %#v", t)
}
+ panic("unreachable")
+}
- // Get the dispatchable type. This may not be used if the given
- // reference has already been resolved, but if not we need to know the
- // type to create.
- t, ok := registeredTypes.lookupType(i.Type)
- if !ok {
- panic(fmt.Errorf("no valid type for %q", i.Type))
+// decodeInterface decodes an interface value.
+func (ds *decodeState) decodeInterface(ods *objectDecodeState, obj reflect.Value, encoded *wire.Interface) {
+ if _, ok := encoded.Type.(wire.TypeSpecNil); ok {
+ // Special case; the nil object. Just decode directly, which
+ // will read nil from the wire (if encoded correctly).
+ ds.decodeObject(ods, obj, encoded.Value)
+ return
}
- if obj.Kind() != reflect.Map {
- // Set the obj to be the given typed value; this actually sets
- // obj to be a non-zero value -- namely, it inserts type
- // information. There's no need to do this for maps.
- obj.Set(reflect.Zero(t))
+ // We now need to resolve the actual type.
+ typ := ds.findType(encoded.Type)
+
+ // We need to imbue type information here, then we can proceed to
+ // decode normally. In order to avoid issues with setting value-types,
+ // we create a new non-interface version of this object. We will then
+ // set the interface object to be equal to whatever we decode.
+ origObj := obj
+ obj = reflect.New(typ).Elem()
+ defer origObj.Set(obj)
+
+ // With the object now having sufficient type information to actually
+ // have Set called on it, we can proceed to decode the value.
+ ds.decodeObject(ods, obj, encoded.Value)
+}
+
+// isFloatEq determines if x and y represent the same value.
+func isFloatEq(x float64, y float64) bool {
+ switch {
+ case math.IsNaN(x):
+ return math.IsNaN(y)
+ case math.IsInf(x, 1):
+ return math.IsInf(y, 1)
+ case math.IsInf(x, -1):
+ return math.IsInf(y, -1)
+ default:
+ return x == y
}
+}
- // Decode the dereferenced element; there is no need to wait here, as
- // the interface object shares the current object state.
- ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type)
+// isComplexEq determines if x and y represent the same value.
+func isComplexEq(x complex128, y complex128) bool {
+ return isFloatEq(real(x), real(y)) && isFloatEq(imag(x), imag(y))
}
// decodeObject decodes a object value.
-func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) {
- ds.push(false, format, param)
- ds.stats.Add(obj)
- ds.stats.Start(obj)
-
- switch x := object.GetValue().(type) {
- case *pb.Object_BoolValue:
- obj.SetBool(x.BoolValue)
- case *pb.Object_StringValue:
- obj.SetString(string(x.StringValue))
- case *pb.Object_Int64Value:
- obj.SetInt(x.Int64Value)
- if obj.Int() != x.Int64Value {
- panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type()))
+func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, encoded wire.Object) {
+ switch x := encoded.(type) {
+ case wire.Nil: // Fast path: first.
+ // We leave obj alone here. That's because if obj represents an
+ // interface, it may have been imbued with type information in
+ // decodeInterface, and we don't want to destroy that.
+ case *wire.Ref:
+ // Nil pointers may be encoded in a "forceValue" context. For
+ // those we just leave it alone as the value will already be
+ // correct (nil).
+ if id := objectID(x.Root); id == 0 {
+ return
}
- case *pb.Object_Uint64Value:
- obj.SetUint(x.Uint64Value)
- if obj.Uint() != x.Uint64Value {
- panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type()))
- }
- case *pb.Object_DoubleValue:
- obj.SetFloat(x.DoubleValue)
- if obj.Float() != x.DoubleValue {
- panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type()))
- }
- case *pb.Object_RefValue:
- // Resolve the pointer itself, even though the object may not
- // be decoded yet. You need to use wait() in order to ensure
- // that is the case. See wait above, and Map.Barrier.
- if id := x.RefValue; id != 0 {
- // Decoding the interface should have imparted type
- // information, so from this point it's safe to resolve
- // and use this dynamic information for actually
- // creating the object in register.
- //
- // (For non-interfaces this is a no-op).
- dyntyp := reflect.TypeOf(obj.Interface())
- if dyntyp.Kind() == reflect.Map {
- // Remove the map object count here to avoid
- // double counting, as this object will be
- // counted again when it gets processed later.
- // We do not add a reference count as the
- // reference is artificial.
- ds.stats.Remove(obj)
- obj.Set(ds.register(id, dyntyp).obj)
- } else if dyntyp.Kind() == reflect.Ptr {
- ds.push(true /* dereference */, "", nil)
- obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
- ds.pop()
- } else {
- obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
+
+ // Note that if this is a map type, we go through a level of
+ // indirection to allow for map aliasing.
+ if obj.Kind() == reflect.Map {
+ v := ds.register(x, obj.Type())
+ if v.IsNil() {
+ // Note that we don't want to clobber the map
+ // if has already been decoded by decodeMap. We
+ // just make it so that we have a consistent
+ // reference when that eventually does happen.
+ v.Set(reflect.MakeMap(v.Type()))
}
- } else {
- // We leave obj alone here. That's because if obj
- // represents an interface, it may have been embued
- // with type information in decodeInterface, and we
- // don't want to destroy that information.
+ obj.Set(v)
+ return
}
- case *pb.Object_SliceValue:
- // It's okay to slice the array here, since the contents will
- // still be provided later on. These semantics are a bit
- // strange but they are handled in the Map.Barrier properly.
- //
- // The special semantics of zero ref apply here too.
- if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 {
- v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem())
- obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity)))
+
+ // Normal assignment: authoritative only if no dots.
+ v := ds.register(x, obj.Type().Elem())
+ if v.IsValid() {
+ obj.Set(unsafePointerTo(v))
}
- case *pb.Object_ArrayValue:
- ds.decodeArray(os, obj, x.ArrayValue)
- case *pb.Object_StructValue:
- ds.decodeStruct(os, obj, x.StructValue)
- case *pb.Object_MapValue:
- ds.decodeMap(os, obj, x.MapValue)
- case *pb.Object_InterfaceValue:
- ds.decodeInterface(os, obj, x.InterfaceValue)
- case *pb.Object_ByteArrayValue:
- copyArray(obj, reflect.ValueOf(x.ByteArrayValue))
- case *pb.Object_Uint16ArrayValue:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := x.Uint16ArrayValue.Values
- t := obj.Slice(0, obj.Len()).Interface().([]uint16)
- if len(t) != len(s) {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ case wire.Bool:
+ obj.SetBool(bool(x))
+ case wire.Int:
+ obj.SetInt(int64(x))
+ if obj.Int() != int64(x) {
+ Failf("signed integer truncated from %v to %v", int64(x), obj.Int())
}
- for i := range s {
- t[i] = uint16(s[i])
+ case wire.Uint:
+ obj.SetUint(uint64(x))
+ if obj.Uint() != uint64(x) {
+ Failf("unsigned integer truncated from %v to %v", uint64(x), obj.Uint())
}
- case *pb.Object_Uint32ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values))
- case *pb.Object_Uint64ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values))
- case *pb.Object_UintptrArrayValue:
- copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
- case *pb.Object_Int8ArrayValue:
- copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
- case *pb.Object_Int16ArrayValue:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := x.Int16ArrayValue.Values
- t := obj.Slice(0, obj.Len()).Interface().([]int16)
- if len(t) != len(s) {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ case wire.Float32:
+ obj.SetFloat(float64(x))
+ case wire.Float64:
+ obj.SetFloat(float64(x))
+ if !isFloatEq(obj.Float(), float64(x)) {
+ Failf("floating point number truncated from %v to %v", float64(x), obj.Float())
}
- for i := range s {
- t[i] = int16(s[i])
+ case *wire.Complex64:
+ obj.SetComplex(complex128(*x))
+ case *wire.Complex128:
+ obj.SetComplex(complex128(*x))
+ if !isComplexEq(obj.Complex(), complex128(*x)) {
+ Failf("complex number truncated from %v to %v", complex128(*x), obj.Complex())
}
- case *pb.Object_Int32ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values))
- case *pb.Object_Int64ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values))
- case *pb.Object_BoolArrayValue:
- copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values))
- case *pb.Object_Float64ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values))
- case *pb.Object_Float32ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values))
+ case *wire.String:
+ obj.SetString(string(*x))
+ case *wire.Slice:
+ // See *wire.Ref above; same applies.
+ if id := objectID(x.Ref.Root); id == 0 {
+ return
+ }
+ // Note that it's fine to slice the array here and assume that
+ // contents will still be filled in later on.
+ typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type.
+ v := ds.register(&x.Ref, typ)
+ obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity)))
+ case *wire.Array:
+ ds.decodeArray(ods, obj, x)
+ case *wire.Struct:
+ ds.decodeStruct(ods, obj, x)
+ case *wire.Map:
+ ds.decodeMap(ods, obj, x)
+ case *wire.Interface:
+ ds.decodeInterface(ods, obj, x)
default:
// Shoud not happen, not propagated as an error.
- panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type()))
- }
-
- ds.stats.Done()
- ds.pop()
-}
-
-func copyArray(dest reflect.Value, src reflect.Value) {
- if dest.Len() != src.Len() {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len()))
+ Failf("unknown object %#v for %q", encoded, obj.Type().Name())
}
- reflect.Copy(dest, castSlice(src, dest.Type().Elem()))
}
-// Deserialize deserializes the object state.
+// Load deserializes the object graph rooted at obj.
//
// This function may panic and should be run in safely().
-func (ds *decodeState) Deserialize(obj reflect.Value) {
- ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()}
- ds.outstanding = 1 // The root object.
+func (ds *decodeState) Load(obj reflect.Value) {
+ ds.stats.init()
+ defer ds.stats.fini(func(id typeID) string {
+ return ds.types.LookupName(id)
+ })
+
+ // Create the root object.
+ ds.objectsByID = append(ds.objectsByID, &objectDecodeState{
+ id: 1,
+ obj: obj,
+ })
+
+ // Read the number of objects.
+ lastID, object, err := ReadHeader(ds.r)
+ if err != nil {
+ Failf("header error: %w", err)
+ }
+ if !object {
+ Failf("object missing")
+ }
+
+ // Decode all objects.
+ var (
+ encoded wire.Object
+ ods *objectDecodeState
+ id = objectID(1)
+ tid = typeID(1)
+ )
+ if err := safely(func() {
+ // Decode all objects in the stream.
+ //
+ // Note that the structure of this decoding loop should match
+ // the raw decoding loop in printer.go.
+ for id <= objectID(lastID) {
+ // Unmarshal the object.
+ encoded = wire.Load(ds.r)
+
+ // Is this a type object? Handle inline.
+ if wt, ok := encoded.(*wire.Type); ok {
+ ds.types.Register(wt)
+ tid++
+ encoded = nil
+ continue
+ }
- // Decode all objects in the stream.
- //
- // See above, we never process objects while we have no outstanding
- // interests (other than the very first object).
- for id := uint64(1); ds.outstanding > 0; id++ {
- os := ds.lookup(id)
- ds.stats.Start(os.obj)
-
- o, err := ds.readObject()
- if err != nil {
- panic(err)
- }
+ // Actually resolve the object.
+ ods = ds.lookup(id)
+ if ods != nil {
+ // Decode the object.
+ ds.decodeObject(ods, ods.obj, encoded)
+ } else {
+ // If an object hasn't had interest registered
+ // previously or isn't yet valid, we deferred
+ // decoding until interest is registered.
+ ds.deferred[id] = encoded
+ }
- if os != nil {
- // Decode the object.
- ds.from = &os.path
- ds.decodeObject(os, os.obj, o, "", nil)
- ds.outstanding--
+ // For error handling.
+ ods = nil
+ encoded = nil
+ id++
+ }
+ }); err != nil {
+ // Include as much information as we can, taking into account
+ // the possible state transitions above.
+ if ods != nil {
+ Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err)
+ } else if encoded != nil {
+ Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err)
} else {
- // If an object hasn't had interest registered
- // previously, we deferred decoding until interest is
- // registered.
- ds.deferred[id] = o
+ Failf("general decoding error: %w", err)
}
-
- ds.stats.Done()
- }
-
- // Check the zero-length header at the end.
- length, object, err := ReadHeader(ds.r)
- if err != nil {
- panic(err)
- }
- if length != 0 {
- panic(fmt.Sprintf("expected zero-length terminal, got %d", length))
- }
- if object {
- panic("expected non-object terminal")
}
// Check if we have any deferred objects.
- if count := len(ds.deferred); count > 0 {
- // Shoud not happen, not propagated as an error.
- panic(fmt.Sprintf("still have %d deferred objects", count))
- }
-
- // Scan and fire all callbacks.
- for _, os := range ds.objectsByID {
- os.checkComplete(ds.stats)
+ for id, encoded := range ds.deferred {
+ // Shoud never happen, the graph was bogus.
+ Failf("still have deferred objects: one is ID %d, %#v", id, encoded)
}
- // Check if we have any remaining dependency cycles.
- for _, os := range ds.objectsByID {
- if !os.complete() {
- // This must be the result of a dependency cycle.
- cycle := os.findCycle()
- var buf bytes.Buffer
- buf.WriteString("dependency cycle: {")
- for i, cycleOS := range cycle {
- if i > 0 {
- buf.WriteString(" => ")
+ // Scan and fire all callbacks. We iterate over the list of incomplete
+ // objects until all have been finished. We stop iterating if no
+ // objects become complete (there is a dependency cycle).
+ //
+ // Note that we iterate backwards here, because there will be a strong
+ // tendendcy for blocking relationships to go from earlier objects to
+ // later (deeper) objects in the graph. This will reduce the number of
+ // iterations required to finish all objects.
+ if err := safely(func() {
+ for ds.pending.Back() != nil {
+ thisCycle := false
+ for ods = ds.pending.Back(); ods != nil; {
+ if ds.checkComplete(ods) {
+ thisCycle = true
+ break
}
- buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type()))
+ ods = ods.Prev()
+ }
+ if !thisCycle {
+ break
}
- buf.WriteString("}")
- // Panic as an error; propagate to the caller.
- panic(errors.New(string(buf.Bytes())))
}
- }
-}
-
-type byteReader struct {
- io.Reader
-}
-
-// ReadByte implements io.ByteReader.
-func (br byteReader) ReadByte() (byte, error) {
- var b [1]byte
- n, err := br.Reader.Read(b[:])
- if n > 0 {
- return b[0], nil
- } else if err != nil {
- return 0, err
- } else {
- return 0, io.ErrUnexpectedEOF
+ }); err != nil {
+ Failf("error executing callbacks for %#v: %w", ods.obj.Interface(), err)
+ }
+
+ // Check if we have any remaining dependency cycles. If there are any
+ // objects left in the pending list, then it must be due to a cycle.
+ if ods := ds.pending.Front(); ods != nil {
+ // This must be the result of a dependency cycle.
+ cycle := ods.findCycle()
+ var buf bytes.Buffer
+ buf.WriteString("dependency cycle: {")
+ for i, cycleOS := range cycle {
+ if i > 0 {
+ buf.WriteString(" => ")
+ }
+ fmt.Fprintf(&buf, "%q", cycleOS.obj.Type())
+ }
+ buf.WriteString("}")
+ Failf("incomplete graph: %s", string(buf.Bytes()))
}
}
@@ -565,45 +706,20 @@ func (br byteReader) ReadByte() (byte, error) {
// Each object written to the statefile is prefixed with a header. See
// WriteHeader for more information; these functions are exported to allow
// non-state writes to the file to play nice with debugging tools.
-func ReadHeader(r io.Reader) (length uint64, object bool, err error) {
+func ReadHeader(r wire.Reader) (length uint64, object bool, err error) {
// Read the header.
- length, err = binary.ReadUvarint(byteReader{r})
+ err = safely(func() {
+ length = wire.LoadUint(r)
+ })
if err != nil {
- return
+ // On the header, pass raw I/O errors.
+ if sErr, ok := err.(*ErrState); ok {
+ return 0, false, sErr.Unwrap()
+ }
}
// Decode whether the object is valid.
- object = length&0x1 != 0
- length = length >> 1
+ object = length&objectFlag != 0
+ length &^= objectFlag
return
}
-
-// readObject reads an object from the stream.
-func (ds *decodeState) readObject() (*pb.Object, error) {
- // Read the header.
- length, object, err := ReadHeader(ds.r)
- if err != nil {
- return nil, err
- }
- if !object {
- return nil, fmt.Errorf("invalid object header")
- }
-
- // Read the object.
- buf := make([]byte, length)
- for done := 0; done < len(buf); {
- n, err := ds.r.Read(buf[done:])
- done += n
- if n == 0 && err != nil {
- return nil, err
- }
- }
-
- // Unmarshal.
- obj := new(pb.Object)
- if err := proto.Unmarshal(buf, obj); err != nil {
- return nil, err
- }
-
- return obj, nil
-}