// Copyright 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package state import ( "bytes" "encoding/binary" "errors" "fmt" "io" "reflect" "sort" "github.com/golang/protobuf/proto" pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto" ) // objectState represents an object that may be in the process of being // decoded. Specifically, it represents either a decoded object, or an an // interest in a future object that will be decoded. When that interest is // registered (via register), the storage for the object will be created, but // it will not be decoded until the object is encountered in the stream. type objectState struct { // id is the id for this object. // // If this field is zero, then this is an anonymous (unregistered, // non-reference primitive) object. This is immutable. id uint64 // obj is the object. This may or may not be valid yet, depending on // whether complete returns true. However, regardless of whether the // object is valid, obj contains a final storage location for the // object. This is immutable. // // Note that this must be addressable (obj.Addr() must not panic). // // The obj passed to the decode methods below will equal this obj only // in the case of decoding the top-level object. However, the passed // obj may represent individual fields, elements of a slice, etc. that // are effectively embedded within the reflect.Value below but with // distinct types. obj reflect.Value // blockedBy is the number of dependencies this object has. blockedBy int // blocking is a list of the objects blocked by this one. blocking []*objectState // callbacks is a set of callbacks to execute on load. callbacks []func() // path is the decoding path to the object. path recoverable } // complete indicates the object is complete. func (os *objectState) complete() bool { return os.blockedBy == 0 && len(os.callbacks) == 0 } // checkComplete checks for completion. If the object is complete, pending // callbacks will be executed and checkComplete will be called on downstream // objects (those depending on this one). func (os *objectState) checkComplete(stats *Stats) { if os.blockedBy > 0 { return } stats.Start(os.obj) // Fire all callbacks. for _, fn := range os.callbacks { fn() } os.callbacks = nil // Clear all blocked objects. for _, other := range os.blocking { other.blockedBy-- other.checkComplete(stats) } os.blocking = nil stats.Done() } // waitFor queues a dependency on the given object. func (os *objectState) waitFor(other *objectState, callback func()) { os.blockedBy++ other.blocking = append(other.blocking, os) if callback != nil { other.callbacks = append(other.callbacks, callback) } } // findCycleFor returns when the given object is found in the blocking set. func (os *objectState) findCycleFor(target *objectState) []*objectState { for _, other := range os.blocking { if other == target { return []*objectState{target} } else if childList := other.findCycleFor(target); childList != nil { return append(childList, other) } } return nil } // findCycle finds a dependency cycle. func (os *objectState) findCycle() []*objectState { return append(os.findCycleFor(os), os) } // decodeState is a graph of objects in the process of being decoded. // // The decode process involves loading the breadth-first graph generated by // encode. This graph is read in it's entirety, ensuring that all object // storage is complete. // // As the graph is being serialized, a set of completion callbacks are // executed. These completion callbacks should form a set of acyclic subgraphs // over the original one. After decoding is complete, the objects are scanned // to ensure that all callbacks are executed, otherwise the callback graph was // not acyclic. type decodeState struct { // objectByID is the set of objects in progress. objectsByID map[uint64]*objectState // deferred are objects that have been read, by no interest has been // registered yet. These will be decoded once interest in registered. deferred map[uint64]*pb.Object // outstanding is the number of outstanding objects. outstanding uint32 // r is the input stream. r io.Reader // stats is the passed stats object. stats *Stats // recoverable is the panic recover facility. recoverable } // lookup looks up an object in decodeState or returns nil if no such object // has been previously registered. func (ds *decodeState) lookup(id uint64) *objectState { return ds.objectsByID[id] } // wait registers a dependency on an object. // // As a special case, we always allow _useable_ references back to the first // decoding object because it may have fields that are already decoded. We also // allow trivial self reference, since they can be handled internally. func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) { switch id { case 0: // Nil pointer; nothing to wait for. fallthrough case waiter.id: // Trivial self reference. fallthrough case 1: // Root object; see above. if callback != nil { callback() } return } // No nil can be returned here. waiter.waitFor(ds.lookup(id), callback) } // waitObject notes a blocking relationship. func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) { if rv, ok := p.Value.(*pb.Object_RefValue); ok { // Refs can encode pointers and maps. ds.wait(os, rv.RefValue, callback) } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok { // See decodeObject; we need to wait for the array (if non-nil). ds.wait(os, sv.SliceValue.RefValue, callback) } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok { // It's an interface (wait recurisvely). ds.waitObject(os, iv.InterfaceValue.Value, callback) } else if callback != nil { // Nothing to wait for: execute the callback immediately. callback() } } // register registers a decode with a type. // // This type is only used to instantiate a new object if it has not been // registered previously. func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState { os, ok := ds.objectsByID[id] if ok { return os } // Record in the object index. if typ.Kind() == reflect.Map { os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()} } else { os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()} } ds.objectsByID[id] = os if o, ok := ds.deferred[id]; ok { // There is a deferred object. delete(ds.deferred, id) // Free memory. ds.decodeObject(os, os.obj, o, "", nil) } else { // There is no deferred object. ds.outstanding++ } return os } // decodeStruct decodes a struct value. func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) { // Set the fields. m := Map{newInternalMap(nil, ds, os)} defer internalMapPool.Put(m.internalMap) for _, field := range s.Fields { m.data = append(m.data, entry{ name: field.Name, object: field.Value, }) } // Sort the fields for efficient searching. // // Technically, these should already appear in sorted order in the // state ordering, so this cost is effectively a single scan to ensure // that the order is correct. if len(m.data) > 1 { sort.Slice(m.data, func(i, j int) bool { return m.data[i].name < m.data[j].name }) } // Invoke the load; this will recursively decode other objects. fns, ok := registeredTypes.lookupFns(obj.Addr().Type()) if ok { // Invoke the loader. fns.invokeLoad(obj.Addr(), m) } else if obj.NumField() == 0 { // Allow anonymous empty structs. return } else { // Propagate an error. panic(fmt.Errorf("unregistered type %s", obj.Type())) } } // decodeMap decodes a map value. func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) { if obj.IsNil() { obj.Set(reflect.MakeMap(obj.Type())) } for i := 0; i < len(m.Keys); i++ { // Decode the objects. kv := reflect.New(obj.Type().Key()).Elem() vv := reflect.New(obj.Type().Elem()).Elem() ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i) ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface()) ds.waitObject(os, m.Keys[i], nil) ds.waitObject(os, m.Values[i], nil) // Set in the map. obj.SetMapIndex(kv, vv) } } // decodeArray decodes an array value. func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) { if len(a.Contents) != obj.Len() { panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents))) } // Decode the contents into the array. for i := 0; i < len(a.Contents); i++ { ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i) ds.waitObject(os, a.Contents[i], nil) } } // decodeInterface decodes an interface value. func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) { // Is this a nil value? if i.Type == "" { return // Just leave obj alone. } // Get the dispatchable type. This may not be used if the given // reference has already been resolved, but if not we need to know the // type to create. t, ok := registeredTypes.lookupType(i.Type) if !ok { panic(fmt.Errorf("no valid type for %q", i.Type)) } if obj.Kind() != reflect.Map { // Set the obj to be the given typed value; this actually sets // obj to be a non-zero value -- namely, it inserts type // information. There's no need to do this for maps. obj.Set(reflect.Zero(t)) } // Decode the dereferenced element; there is no need to wait here, as // the interface object shares the current object state. ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type) } // decodeObject decodes a object value. func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) { ds.push(false, format, param) ds.stats.Add(obj) ds.stats.Start(obj) switch x := object.GetValue().(type) { case *pb.Object_BoolValue: obj.SetBool(x.BoolValue) case *pb.Object_StringValue: obj.SetString(string(x.StringValue)) case *pb.Object_Int64Value: obj.SetInt(x.Int64Value) if obj.Int() != x.Int64Value { panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type())) } case *pb.Object_Uint64Value: obj.SetUint(x.Uint64Value) if obj.Uint() != x.Uint64Value { panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type())) } case *pb.Object_DoubleValue: obj.SetFloat(x.DoubleValue) if obj.Float() != x.DoubleValue { panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type())) } case *pb.Object_RefValue: // Resolve the pointer itself, even though the object may not // be decoded yet. You need to use wait() in order to ensure // that is the case. See wait above, and Map.Barrier. if id := x.RefValue; id != 0 { // Decoding the interface should have imparted type // information, so from this point it's safe to resolve // and use this dynamic information for actually // creating the object in register. // // (For non-interfaces this is a no-op). dyntyp := reflect.TypeOf(obj.Interface()) if dyntyp.Kind() == reflect.Map { // Remove the map object count here to avoid // double counting, as this object will be // counted again when it gets processed later. // We do not add a reference count as the // reference is artificial. ds.stats.Remove(obj) obj.Set(ds.register(id, dyntyp).obj) } else if dyntyp.Kind() == reflect.Ptr { ds.push(true /* dereference */, "", nil) obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) ds.pop() } else { obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) } } else { // We leave obj alone here. That's because if obj // represents an interface, it may have been embued // with type information in decodeInterface, and we // don't want to destroy that information. } case *pb.Object_SliceValue: // It's okay to slice the array here, since the contents will // still be provided later on. These semantics are a bit // strange but they are handled in the Map.Barrier properly. // // The special semantics of zero ref apply here too. if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 { v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem()) obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity))) } case *pb.Object_ArrayValue: ds.decodeArray(os, obj, x.ArrayValue) case *pb.Object_StructValue: ds.decodeStruct(os, obj, x.StructValue) case *pb.Object_MapValue: ds.decodeMap(os, obj, x.MapValue) case *pb.Object_InterfaceValue: ds.decodeInterface(os, obj, x.InterfaceValue) case *pb.Object_ByteArrayValue: copyArray(obj, reflect.ValueOf(x.ByteArrayValue)) case *pb.Object_Uint16ArrayValue: // 16-bit slices are serialized as 32-bit slices. // See object.proto for details. s := x.Uint16ArrayValue.Values t := obj.Slice(0, obj.Len()).Interface().([]uint16) if len(t) != len(s) { panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) } for i := range s { t[i] = uint16(s[i]) } case *pb.Object_Uint32ArrayValue: copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values)) case *pb.Object_Uint64ArrayValue: copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values)) case *pb.Object_UintptrArrayValue: copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0)))) case *pb.Object_Int8ArrayValue: copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0)))) case *pb.Object_Int16ArrayValue: // 16-bit slices are serialized as 32-bit slices. // See object.proto for details. s := x.Int16ArrayValue.Values t := obj.Slice(0, obj.Len()).Interface().([]int16) if len(t) != len(s) { panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) } for i := range s { t[i] = int16(s[i]) } case *pb.Object_Int32ArrayValue: copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values)) case *pb.Object_Int64ArrayValue: copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values)) case *pb.Object_BoolArrayValue: copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values)) case *pb.Object_Float64ArrayValue: copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values)) case *pb.Object_Float32ArrayValue: copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values)) default: // Shoud not happen, not propagated as an error. panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type())) } ds.stats.Done() ds.pop() } func copyArray(dest reflect.Value, src reflect.Value) { if dest.Len() != src.Len() { panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len())) } reflect.Copy(dest, castSlice(src, dest.Type().Elem())) } // Deserialize deserializes the object state. // // This function may panic and should be run in safely(). func (ds *decodeState) Deserialize(obj reflect.Value) { ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()} ds.outstanding = 1 // The root object. // Decode all objects in the stream. // // See above, we never process objects while we have no outstanding // interests (other than the very first object). for id := uint64(1); ds.outstanding > 0; id++ { os := ds.lookup(id) ds.stats.Start(os.obj) o, err := ds.readObject() if err != nil { panic(err) } if os != nil { // Decode the object. ds.from = &os.path ds.decodeObject(os, os.obj, o, "", nil) ds.outstanding-- } else { // If an object hasn't had interest registered // previously, we deferred decoding until interest is // registered. ds.deferred[id] = o } ds.stats.Done() } // Check the zero-length header at the end. length, object, err := ReadHeader(ds.r) if err != nil { panic(err) } if length != 0 { panic(fmt.Sprintf("expected zero-length terminal, got %d", length)) } if object { panic("expected non-object terminal") } // Check if we have any deferred objects. if count := len(ds.deferred); count > 0 { // Shoud not happen, not propagated as an error. panic(fmt.Sprintf("still have %d deferred objects", count)) } // Scan and fire all callbacks. for _, os := range ds.objectsByID { os.checkComplete(ds.stats) } // Check if we have any remaining dependency cycles. for _, os := range ds.objectsByID { if !os.complete() { // This must be the result of a dependency cycle. cycle := os.findCycle() var buf bytes.Buffer buf.WriteString("dependency cycle: {") for i, cycleOS := range cycle { if i > 0 { buf.WriteString(" => ") } buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type())) } buf.WriteString("}") // Panic as an error; propagate to the caller. panic(errors.New(string(buf.Bytes()))) } } } type byteReader struct { io.Reader } // ReadByte implements io.ByteReader. func (br byteReader) ReadByte() (byte, error) { var b [1]byte n, err := br.Reader.Read(b[:]) if n > 0 { return b[0], nil } else if err != nil { return 0, err } else { return 0, io.ErrUnexpectedEOF } } // ReadHeader reads an object header. // // Each object written to the statefile is prefixed with a header. See // WriteHeader for more information; these functions are exported to allow // non-state writes to the file to play nice with debugging tools. func ReadHeader(r io.Reader) (length uint64, object bool, err error) { // Read the header. length, err = binary.ReadUvarint(byteReader{r}) if err != nil { return } // Decode whether the object is valid. object = length&0x1 != 0 length = length >> 1 return } // readObject reads an object from the stream. func (ds *decodeState) readObject() (*pb.Object, error) { // Read the header. length, object, err := ReadHeader(ds.r) if err != nil { return nil, err } if !object { return nil, fmt.Errorf("invalid object header") } // Read the object. buf := make([]byte, length) for done := 0; done < len(buf); { n, err := ds.r.Read(buf[done:]) done += n if n == 0 && err != nil { return nil, err } } // Unmarshal. obj := new(pb.Object) if err := proto.Unmarshal(buf, obj); err != nil { return nil, err } return obj, nil }