// Copyright 2018 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package state import ( "bytes" "context" "fmt" "math" "reflect" "gvisor.dev/gvisor/pkg/state/wire" ) // internalCallback is a interface called on object completion. // // There are two implementations: objectDecodeState & userCallback. type internalCallback interface { // source returns the dependent object. May be nil. source() *objectDecodeState // callbackRun executes the callback. callbackRun() } // userCallback is an implementation of internalCallback. type userCallback func() // source implements internalCallback.source. func (userCallback) source() *objectDecodeState { return nil } // callbackRun implements internalCallback.callbackRun. func (uc userCallback) callbackRun() { uc() } // objectDecodeState represents an object that may be in the process of being // decoded. Specifically, it represents either a decoded object, or an an // interest in a future object that will be decoded. When that interest is // registered (via register), the storage for the object will be created, but // it will not be decoded until the object is encountered in the stream. type objectDecodeState struct { // id is the id for this object. id objectID // typ is the id for this typeID. This may be zero if this is not a // type-registered structure. typ typeID // obj is the object. This may or may not be valid yet, depending on // whether complete returns true. However, regardless of whether the // object is valid, obj contains a final storage location for the // object. This is immutable. // // Note that this must be addressable (obj.Addr() must not panic). // // The obj passed to the decode methods below will equal this obj only // in the case of decoding the top-level object. However, the passed // obj may represent individual fields, elements of a slice, etc. that // are effectively embedded within the reflect.Value below but with // distinct types. obj reflect.Value // blockedBy is the number of dependencies this object has. blockedBy int // callbacksInline is inline storage for callbacks. callbacksInline [2]internalCallback // callbacks is a set of callbacks to execute on load. callbacks []internalCallback completeEntry } // addCallback adds a callback to the objectDecodeState. func (ods *objectDecodeState) addCallback(ic internalCallback) { if ods.callbacks == nil { ods.callbacks = ods.callbacksInline[:0] } ods.callbacks = append(ods.callbacks, ic) } // findCycleFor returns when the given object is found in the blocking set. func (ods *objectDecodeState) findCycleFor(target *objectDecodeState) []*objectDecodeState { for _, ic := range ods.callbacks { other := ic.source() if other != nil && other == target { return []*objectDecodeState{target} } else if childList := other.findCycleFor(target); childList != nil { return append(childList, other) } } // This should not occur. Failf("no deadlock found?") panic("unreachable") } // findCycle finds a dependency cycle. func (ods *objectDecodeState) findCycle() []*objectDecodeState { return append(ods.findCycleFor(ods), ods) } // source implements internalCallback.source. func (ods *objectDecodeState) source() *objectDecodeState { return ods } // callbackRun implements internalCallback.callbackRun. func (ods *objectDecodeState) callbackRun() { ods.blockedBy-- } // decodeState is a graph of objects in the process of being decoded. // // The decode process involves loading the breadth-first graph generated by // encode. This graph is read in it's entirety, ensuring that all object // storage is complete. // // As the graph is being serialized, a set of completion callbacks are // executed. These completion callbacks should form a set of acyclic subgraphs // over the original one. After decoding is complete, the objects are scanned // to ensure that all callbacks are executed, otherwise the callback graph was // not acyclic. type decodeState struct { // ctx is the decode context. ctx context.Context // r is the input stream. r wire.Reader // types is the type database. types typeDecodeDatabase // objectByID is the set of objects in progress. objectsByID []*objectDecodeState // deferred are objects that have been read, by no interest has been // registered yet. These will be decoded once interest in registered. deferred map[objectID]wire.Object // pending is the set of objects that are not yet complete. pending completeList // stats tracks time data. stats Stats } // lookup looks up an object in decodeState or returns nil if no such object // has been previously registered. func (ds *decodeState) lookup(id objectID) *objectDecodeState { if len(ds.objectsByID) < int(id) { return nil } return ds.objectsByID[id-1] } // checkComplete checks for completion. func (ds *decodeState) checkComplete(ods *objectDecodeState) bool { // Still blocked? if ods.blockedBy > 0 { return false } // Track stats if relevant. if ods.callbacks != nil && ods.typ != 0 { ds.stats.start(ods.typ) defer ds.stats.done() } // Fire all callbacks. for _, ic := range ods.callbacks { ic.callbackRun() } // Mark completed. cbs := ods.callbacks ods.callbacks = nil ds.pending.Remove(ods) // Recursively check others. for _, ic := range cbs { if other := ic.source(); other != nil && other.blockedBy == 0 { ds.checkComplete(other) } } return true // All set. } // wait registers a dependency on an object. // // As a special case, we always allow _useable_ references back to the first // decoding object because it may have fields that are already decoded. We also // allow trivial self reference, since they can be handled internally. func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback func()) { switch id { case waiter.id: // Trivial self reference. fallthrough case 1: // Root object; see above. if callback != nil { callback() } return } // Mark as blocked. waiter.blockedBy++ // No nil can be returned here. other := ds.lookup(id) if callback != nil { // Add the additional user callback. other.addCallback(userCallback(callback)) } // Mark waiter as unblocked. other.addCallback(waiter) } // waitObject notes a blocking relationship. func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, callback func()) { if rv, ok := encoded.(*wire.Ref); ok && rv.Root != 0 { // Refs can encode pointers and maps. ds.wait(ods, objectID(rv.Root), callback) } else if sv, ok := encoded.(*wire.Slice); ok && sv.Ref.Root != 0 { // See decodeObject; we need to wait for the array (if non-nil). ds.wait(ods, objectID(sv.Ref.Root), callback) } else if iv, ok := encoded.(*wire.Interface); ok { // It's an interface (wait recurisvely). ds.waitObject(ods, iv.Value, callback) } else if callback != nil { // Nothing to wait for: execute the callback immediately. callback() } } // walkChild returns a child object from obj, given an accessor path. This is // the decode-side equivalent to traverse in encode.go. // // For the purposes of this function, a child object is either a field within a // struct or an array element, with one such indirection per element in // path. The returned value may be an unexported field, so it may not be // directly assignable. See unsafePointerTo. func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value { // See wire.Ref.Dots. The path here is specified in reverse order. for i := len(path) - 1; i >= 0; i-- { switch pc := path[i].(type) { case *wire.FieldName: // Must be a pointer. if obj.Kind() != reflect.Struct { Failf("next component in child path is a field name, but the current object is not a struct. Path: %v, current obj: %#v", path, obj) } obj = obj.FieldByName(string(*pc)) case wire.Index: // Embedded. if obj.Kind() != reflect.Array { Failf("next component in child path is an array index, but the current object is not an array. Path: %v, current obj: %#v", path, obj) } obj = obj.Index(int(pc)) default: panic("unreachable: switch should be exhaustive") } } return obj } // register registers a decode with a type. // // This type is only used to instantiate a new object if it has not been // registered previously. This depends on the type provided if none is // available in the object itself. func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value { // Grow the objectsByID slice. id := objectID(r.Root) if len(ds.objectsByID) < int(id) { ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...) } // Does this object already exist? ods := ds.objectsByID[id-1] if ods != nil { return walkChild(r.Dots, ods.obj) } // Create the object. if len(r.Dots) != 0 { typ = ds.findType(r.Type) } v := reflect.New(typ) ods = &objectDecodeState{ id: id, obj: v.Elem(), } ds.objectsByID[id-1] = ods ds.pending.PushBack(ods) // Process any deferred objects & callbacks. if encoded, ok := ds.deferred[id]; ok { delete(ds.deferred, id) ds.decodeObject(ods, ods.obj, encoded) } return walkChild(r.Dots, ods.obj) } // objectDecoder is for decoding structs. type objectDecoder struct { // ds is decodeState. ds *decodeState // ods is current object being decoded. ods *objectDecodeState // reconciledTypeEntry is the reconciled type information. rte *reconciledTypeEntry // encoded is the encoded object state. encoded *wire.Struct } // load is helper for the public methods on Source. func (od *objectDecoder) load(slot int, objPtr reflect.Value, wait bool, fn func()) { // Note that we have reconciled the type and may remap the fields here // to match what's expected by the decoder. The "slot" parameter here // is in terms of the local type, where the fields in the encoded // object are in terms of the wire object's type, which might be in a // different order (but will have the same fields). v := *od.encoded.Field(od.rte.FieldOrder[slot]) od.ds.decodeObject(od.ods, objPtr.Elem(), v) if wait { // Mark this individual object a blocker. od.ds.waitObject(od.ods, v, fn) } } // aterLoad implements Source.AfterLoad. func (od *objectDecoder) afterLoad(fn func()) { // Queue the local callback; this will execute when all of the above // data dependencies have been cleared. od.ods.addCallback(userCallback(fn)) } // decodeStruct decodes a struct value. func (ds *decodeState) decodeStruct(ods *objectDecodeState, obj reflect.Value, encoded *wire.Struct) { if encoded.TypeID == 0 { // Allow anonymous empty structs, but only if the encoded // object also has no fields. if encoded.Fields() == 0 && obj.NumField() == 0 { return } // Propagate an error. Failf("empty struct on wire %#v has field mismatch with type %q", encoded, obj.Type().Name()) } // Lookup the object type. rte := ds.types.Lookup(typeID(encoded.TypeID), obj.Type()) ods.typ = typeID(encoded.TypeID) // Invoke the loader. od := objectDecoder{ ds: ds, ods: ods, rte: rte, encoded: encoded, } ds.stats.start(ods.typ) defer ds.stats.done() if sl, ok := obj.Addr().Interface().(SaverLoader); ok { // Note: may be a registered empty struct which does not // implement the saver/loader interfaces. sl.StateLoad(Source{internal: od}) } } // decodeMap decodes a map value. func (ds *decodeState) decodeMap(ods *objectDecodeState, obj reflect.Value, encoded *wire.Map) { if obj.IsNil() { // See pointerTo. obj.Set(reflect.MakeMap(obj.Type())) } for i := 0; i < len(encoded.Keys); i++ { // Decode the objects. kv := reflect.New(obj.Type().Key()).Elem() vv := reflect.New(obj.Type().Elem()).Elem() ds.decodeObject(ods, kv, encoded.Keys[i]) ds.decodeObject(ods, vv, encoded.Values[i]) ds.waitObject(ods, encoded.Keys[i], nil) ds.waitObject(ods, encoded.Values[i], nil) // Set in the map. obj.SetMapIndex(kv, vv) } } // decodeArray decodes an array value. func (ds *decodeState) decodeArray(ods *objectDecodeState, obj reflect.Value, encoded *wire.Array) { if len(encoded.Contents) != obj.Len() { Failf("mismatching array length expect=%d, actual=%d", obj.Len(), len(encoded.Contents)) } // Decode the contents into the array. for i := 0; i < len(encoded.Contents); i++ { ds.decodeObject(ods, obj.Index(i), encoded.Contents[i]) ds.waitObject(ods, encoded.Contents[i], nil) } } // findType finds the type for the given wire.TypeSpecs. func (ds *decodeState) findType(t wire.TypeSpec) reflect.Type { switch x := t.(type) { case wire.TypeID: typ := ds.types.LookupType(typeID(x)) rte := ds.types.Lookup(typeID(x), typ) return rte.LocalType case *wire.TypeSpecPointer: return reflect.PtrTo(ds.findType(x.Type)) case *wire.TypeSpecArray: return reflect.ArrayOf(int(x.Count), ds.findType(x.Type)) case *wire.TypeSpecSlice: return reflect.SliceOf(ds.findType(x.Type)) case *wire.TypeSpecMap: return reflect.MapOf(ds.findType(x.Key), ds.findType(x.Value)) default: // Should not happen. Failf("unknown type %#v", t) } panic("unreachable") } // decodeInterface decodes an interface value. func (ds *decodeState) decodeInterface(ods *objectDecodeState, obj reflect.Value, encoded *wire.Interface) { if _, ok := encoded.Type.(wire.TypeSpecNil); ok { // Special case; the nil object. Just decode directly, which // will read nil from the wire (if encoded correctly). ds.decodeObject(ods, obj, encoded.Value) return } // We now need to resolve the actual type. typ := ds.findType(encoded.Type) // We need to imbue type information here, then we can proceed to // decode normally. In order to avoid issues with setting value-types, // we create a new non-interface version of this object. We will then // set the interface object to be equal to whatever we decode. origObj := obj obj = reflect.New(typ).Elem() defer origObj.Set(obj) // With the object now having sufficient type information to actually // have Set called on it, we can proceed to decode the value. ds.decodeObject(ods, obj, encoded.Value) } // isFloatEq determines if x and y represent the same value. func isFloatEq(x float64, y float64) bool { switch { case math.IsNaN(x): return math.IsNaN(y) case math.IsInf(x, 1): return math.IsInf(y, 1) case math.IsInf(x, -1): return math.IsInf(y, -1) default: return x == y } } // isComplexEq determines if x and y represent the same value. func isComplexEq(x complex128, y complex128) bool { return isFloatEq(real(x), real(y)) && isFloatEq(imag(x), imag(y)) } // decodeObject decodes a object value. func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, encoded wire.Object) { switch x := encoded.(type) { case wire.Nil: // Fast path: first. // We leave obj alone here. That's because if obj represents an // interface, it may have been imbued with type information in // decodeInterface, and we don't want to destroy that. case *wire.Ref: // Nil pointers may be encoded in a "forceValue" context. For // those we just leave it alone as the value will already be // correct (nil). if id := objectID(x.Root); id == 0 { return } // Note that if this is a map type, we go through a level of // indirection to allow for map aliasing. if obj.Kind() == reflect.Map { v := ds.register(x, obj.Type()) if v.IsNil() { // Note that we don't want to clobber the map // if has already been decoded by decodeMap. We // just make it so that we have a consistent // reference when that eventually does happen. v.Set(reflect.MakeMap(v.Type())) } obj.Set(v) return } // Normal assignment: authoritative only if no dots. v := ds.register(x, obj.Type().Elem()) if v.IsValid() { obj.Set(unsafePointerTo(v)) } case wire.Bool: obj.SetBool(bool(x)) case wire.Int: obj.SetInt(int64(x)) if obj.Int() != int64(x) { Failf("signed integer truncated from %v to %v", int64(x), obj.Int()) } case wire.Uint: obj.SetUint(uint64(x)) if obj.Uint() != uint64(x) { Failf("unsigned integer truncated from %v to %v", uint64(x), obj.Uint()) } case wire.Float32: obj.SetFloat(float64(x)) case wire.Float64: obj.SetFloat(float64(x)) if !isFloatEq(obj.Float(), float64(x)) { Failf("floating point number truncated from %v to %v", float64(x), obj.Float()) } case *wire.Complex64: obj.SetComplex(complex128(*x)) case *wire.Complex128: obj.SetComplex(complex128(*x)) if !isComplexEq(obj.Complex(), complex128(*x)) { Failf("complex number truncated from %v to %v", complex128(*x), obj.Complex()) } case *wire.String: obj.SetString(string(*x)) case *wire.Slice: // See *wire.Ref above; same applies. if id := objectID(x.Ref.Root); id == 0 { return } // Note that it's fine to slice the array here and assume that // contents will still be filled in later on. typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type. v := ds.register(&x.Ref, typ) obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity))) case *wire.Array: ds.decodeArray(ods, obj, x) case *wire.Struct: ds.decodeStruct(ods, obj, x) case *wire.Map: ds.decodeMap(ods, obj, x) case *wire.Interface: ds.decodeInterface(ods, obj, x) default: // Shoud not happen, not propagated as an error. Failf("unknown object %#v for %q", encoded, obj.Type().Name()) } } // Load deserializes the object graph rooted at obj. // // This function may panic and should be run in safely(). func (ds *decodeState) Load(obj reflect.Value) { ds.stats.init() defer ds.stats.fini(func(id typeID) string { return ds.types.LookupName(id) }) // Create the root object. ds.objectsByID = append(ds.objectsByID, &objectDecodeState{ id: 1, obj: obj, }) // Read the number of objects. lastID, object, err := ReadHeader(ds.r) if err != nil { Failf("header error: %w", err) } if !object { Failf("object missing") } // Decode all objects. var ( encoded wire.Object ods *objectDecodeState id = objectID(1) tid = typeID(1) ) if err := safely(func() { // Decode all objects in the stream. // // Note that the structure of this decoding loop should match // the raw decoding loop in printer.go. for id <= objectID(lastID) { // Unmarshal the object. encoded = wire.Load(ds.r) // Is this a type object? Handle inline. if wt, ok := encoded.(*wire.Type); ok { ds.types.Register(wt) tid++ encoded = nil continue } // Actually resolve the object. ods = ds.lookup(id) if ods != nil { // Decode the object. ds.decodeObject(ods, ods.obj, encoded) } else { // If an object hasn't had interest registered // previously or isn't yet valid, we deferred // decoding until interest is registered. ds.deferred[id] = encoded } // For error handling. ods = nil encoded = nil id++ } }); err != nil { // Include as much information as we can, taking into account // the possible state transitions above. if ods != nil { Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err) } else if encoded != nil { Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err) } else { Failf("general decoding error: %w", err) } } // Check if we have any deferred objects. for id, encoded := range ds.deferred { // Shoud never happen, the graph was bogus. Failf("still have deferred objects: one is ID %d, %#v", id, encoded) } // Scan and fire all callbacks. We iterate over the list of incomplete // objects until all have been finished. We stop iterating if no // objects become complete (there is a dependency cycle). // // Note that we iterate backwards here, because there will be a strong // tendendcy for blocking relationships to go from earlier objects to // later (deeper) objects in the graph. This will reduce the number of // iterations required to finish all objects. if err := safely(func() { for ds.pending.Back() != nil { thisCycle := false for ods = ds.pending.Back(); ods != nil; { if ds.checkComplete(ods) { thisCycle = true break } ods = ods.Prev() } if !thisCycle { break } } }); err != nil { Failf("error executing callbacks for %#v: %w", ods.obj.Interface(), err) } // Check if we have any remaining dependency cycles. If there are any // objects left in the pending list, then it must be due to a cycle. if ods := ds.pending.Front(); ods != nil { // This must be the result of a dependency cycle. cycle := ods.findCycle() var buf bytes.Buffer buf.WriteString("dependency cycle: {") for i, cycleOS := range cycle { if i > 0 { buf.WriteString(" => ") } fmt.Fprintf(&buf, "%q", cycleOS.obj.Type()) } buf.WriteString("}") Failf("incomplete graph: %s", string(buf.Bytes())) } } // ReadHeader reads an object header. // // Each object written to the statefile is prefixed with a header. See // WriteHeader for more information; these functions are exported to allow // non-state writes to the file to play nice with debugging tools. func ReadHeader(r wire.Reader) (length uint64, object bool, err error) { // Read the header. err = safely(func() { length = wire.LoadUint(r) }) if err != nil { // On the header, pass raw I/O errors. if sErr, ok := err.(*ErrState); ok { return 0, false, sErr.Unwrap() } } // Decode whether the object is valid. object = length&objectFlag != 0 length &^= objectFlag return }