diff options
Diffstat (limited to 'pkg/state/decode.go')
-rw-r--r-- | pkg/state/decode.go | 594 |
1 files changed, 594 insertions, 0 deletions
diff --git a/pkg/state/decode.go b/pkg/state/decode.go new file mode 100644 index 000000000..05758495b --- /dev/null +++ b/pkg/state/decode.go @@ -0,0 +1,594 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "reflect" + "sort" + + "github.com/golang/protobuf/proto" + pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto" +) + +// objectState represents an object that may be in the process of being +// decoded. Specifically, it represents either a decoded object, or an an +// interest in a future object that will be decoded. When that interest is +// registered (via register), the storage for the object will be created, but +// it will not be decoded until the object is encountered in the stream. +type objectState struct { + // id is the id for this object. + // + // If this field is zero, then this is an anonymous (unregistered, + // non-reference primitive) object. This is immutable. + id uint64 + + // obj is the object. This may or may not be valid yet, depending on + // whether complete returns true. However, regardless of whether the + // object is valid, obj contains a final storage location for the + // object. This is immutable. + // + // Note that this must be addressable (obj.Addr() must not panic). + // + // The obj passed to the decode methods below will equal this obj only + // in the case of decoding the top-level object. However, the passed + // obj may represent individual fields, elements of a slice, etc. that + // are effectively embedded within the reflect.Value below but with + // distinct types. + obj reflect.Value + + // blockedBy is the number of dependencies this object has. + blockedBy int + + // blocking is a list of the objects blocked by this one. + blocking []*objectState + + // callbacks is a set of callbacks to execute on load. + callbacks []func() + + // path is the decoding path to the object. + path recoverable +} + +// complete indicates the object is complete. +func (os *objectState) complete() bool { + return os.blockedBy == 0 && len(os.callbacks) == 0 +} + +// checkComplete checks for completion. If the object is complete, pending +// callbacks will be executed and checkComplete will be called on downstream +// objects (those depending on this one). +func (os *objectState) checkComplete(stats *Stats) { + if os.blockedBy > 0 { + return + } + + // Fire all callbacks. + for _, fn := range os.callbacks { + stats.Start(os.obj) + fn() + stats.Done() + } + os.callbacks = nil + + // Clear all blocked objects. + for _, other := range os.blocking { + other.blockedBy-- + other.checkComplete(stats) + } + os.blocking = nil +} + +// waitFor queues a dependency on the given object. +func (os *objectState) waitFor(other *objectState, callback func()) { + os.blockedBy++ + other.blocking = append(other.blocking, os) + if callback != nil { + other.callbacks = append(other.callbacks, callback) + } +} + +// findCycleFor returns when the given object is found in the blocking set. +func (os *objectState) findCycleFor(target *objectState) []*objectState { + for _, other := range os.blocking { + if other == target { + return []*objectState{target} + } else if childList := other.findCycleFor(target); childList != nil { + return append(childList, other) + } + } + return nil +} + +// findCycle finds a dependency cycle. +func (os *objectState) findCycle() []*objectState { + return append(os.findCycleFor(os), os) +} + +// decodeState is a graph of objects in the process of being decoded. +// +// The decode process involves loading the breadth-first graph generated by +// encode. This graph is read in it's entirety, ensuring that all object +// storage is complete. +// +// As the graph is being serialized, a set of completion callbacks are +// executed. These completion callbacks should form a set of acyclic subgraphs +// over the original one. After decoding is complete, the objects are scanned +// to ensure that all callbacks are executed, otherwise the callback graph was +// not acyclic. +type decodeState struct { + // objectByID is the set of objects in progress. + objectsByID map[uint64]*objectState + + // deferred are objects that have been read, by no interest has been + // registered yet. These will be decoded once interest in registered. + deferred map[uint64]*pb.Object + + // outstanding is the number of outstanding objects. + outstanding uint32 + + // r is the input stream. + r io.Reader + + // stats is the passed stats object. + stats *Stats + + // recoverable is the panic recover facility. + recoverable +} + +// lookup looks up an object in decodeState or returns nil if no such object +// has been previously registered. +func (ds *decodeState) lookup(id uint64) *objectState { + return ds.objectsByID[id] +} + +// wait registers a dependency on an object. +// +// As a special case, we always allow _useable_ references back to the first +// decoding object because it may have fields that are already decoded. We also +// allow trivial self reference, since they can be handled internally. +func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) { + switch id { + case 0: + // Nil pointer; nothing to wait for. + fallthrough + case waiter.id: + // Trivial self reference. + fallthrough + case 1: + // Root object; see above. + if callback != nil { + callback() + } + return + } + + // No nil can be returned here. + waiter.waitFor(ds.lookup(id), callback) +} + +// waitObject notes a blocking relationship. +func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) { + if rv, ok := p.Value.(*pb.Object_RefValue); ok { + // Refs can encode pointers and maps. + ds.wait(os, rv.RefValue, callback) + } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok { + // See decodeObject; we need to wait for the array (if non-nil). + ds.wait(os, sv.SliceValue.RefValue, callback) + } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok { + // It's an interface (wait recurisvely). + ds.waitObject(os, iv.InterfaceValue.Value, callback) + } else if callback != nil { + // Nothing to wait for: execute the callback immediately. + callback() + } +} + +// register registers a decode with a type. +// +// This type is only used to instantiate a new object if it has not been +// registered previously. +func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState { + os, ok := ds.objectsByID[id] + if ok { + return os + } + + // Record in the object index. + if typ.Kind() == reflect.Map { + os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()} + } else { + os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()} + } + ds.objectsByID[id] = os + + if o, ok := ds.deferred[id]; ok { + // There is a deferred object. + delete(ds.deferred, id) // Free memory. + ds.decodeObject(os, os.obj, o, "", nil) + } else { + // There is no deferred object. + ds.outstanding++ + } + + return os +} + +// decodeStruct decodes a struct value. +func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) { + // Set the fields. + m := Map{newInternalMap(nil, ds, os)} + defer internalMapPool.Put(m.internalMap) + for _, field := range s.Fields { + m.data = append(m.data, entry{ + name: field.Name, + object: field.Value, + }) + } + + // Sort the fields for efficient searching. + // + // Technically, these should already appear in sorted order in the + // state ordering, so this cost is effectively a single scan to ensure + // that the order is correct. + if len(m.data) > 1 { + sort.Slice(m.data, func(i, j int) bool { + return m.data[i].name < m.data[j].name + }) + } + + // Invoke the load; this will recursively decode other objects. + fns, ok := registeredTypes.lookupFns(obj.Addr().Type()) + if ok { + // Invoke the loader. + fns.invokeLoad(obj.Addr(), m) + } else if obj.NumField() == 0 { + // Allow anonymous empty structs. + return + } else { + // Propagate an error. + panic(fmt.Errorf("unregistered type %s", obj.Type())) + } +} + +// decodeMap decodes a map value. +func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) { + if obj.IsNil() { + obj.Set(reflect.MakeMap(obj.Type())) + } + for i := 0; i < len(m.Keys); i++ { + // Decode the objects. + kv := reflect.New(obj.Type().Key()).Elem() + vv := reflect.New(obj.Type().Elem()).Elem() + ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i) + ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface()) + ds.waitObject(os, m.Keys[i], nil) + ds.waitObject(os, m.Values[i], nil) + + // Set in the map. + obj.SetMapIndex(kv, vv) + } +} + +// decodeArray decodes an array value. +func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) { + if len(a.Contents) != obj.Len() { + panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents))) + } + // Decode the contents into the array. + for i := 0; i < len(a.Contents); i++ { + ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i) + ds.waitObject(os, a.Contents[i], nil) + } +} + +// decodeInterface decodes an interface value. +func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) { + // Is this a nil value? + if i.Type == "" { + return // Just leave obj alone. + } + + // Get the dispatchable type. This may not be used if the given + // reference has already been resolved, but if not we need to know the + // type to create. + t, ok := registeredTypes.lookupType(i.Type) + if !ok { + panic(fmt.Errorf("no valid type for %q", i.Type)) + } + + if obj.Kind() != reflect.Map { + // Set the obj to be the given typed value; this actually sets + // obj to be a non-zero value -- namely, it inserts type + // information. There's no need to do this for maps. + obj.Set(reflect.Zero(t)) + } + + // Decode the dereferenced element; there is no need to wait here, as + // the interface object shares the current object state. + ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type) +} + +// decodeObject decodes a object value. +func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) { + ds.push(false, format, param) + ds.stats.Start(obj) + + switch x := object.GetValue().(type) { + case *pb.Object_BoolValue: + obj.SetBool(x.BoolValue) + case *pb.Object_StringValue: + obj.SetString(x.StringValue) + case *pb.Object_Int64Value: + obj.SetInt(x.Int64Value) + if obj.Int() != x.Int64Value { + panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type())) + } + case *pb.Object_Uint64Value: + obj.SetUint(x.Uint64Value) + if obj.Uint() != x.Uint64Value { + panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type())) + } + case *pb.Object_DoubleValue: + obj.SetFloat(x.DoubleValue) + if obj.Float() != x.DoubleValue { + panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type())) + } + case *pb.Object_RefValue: + // Resolve the pointer itself, even though the object may not + // be decoded yet. You need to use wait() in order to ensure + // that is the case. See wait above, and Map.Barrier. + if id := x.RefValue; id != 0 { + // Decoding the interface should have imparted type + // information, so from this point it's safe to resolve + // and use this dynamic information for actually + // creating the object in register. + // + // (For non-interfaces this is a no-op). + dyntyp := reflect.TypeOf(obj.Interface()) + if dyntyp.Kind() == reflect.Map { + obj.Set(ds.register(id, dyntyp).obj) + } else if dyntyp.Kind() == reflect.Ptr { + ds.push(true /* dereference */, "", nil) + obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) + ds.pop() + } else { + obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) + } + } else { + // We leave obj alone here. That's because if obj + // represents an interface, it may have been embued + // with type information in decodeInterface, and we + // don't want to destroy that information. + } + case *pb.Object_SliceValue: + // It's okay to slice the array here, since the contents will + // still be provided later on. These semantics are a bit + // strange but they are handled in the Map.Barrier properly. + // + // The special semantics of zero ref apply here too. + if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 { + v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem()) + obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity))) + } + case *pb.Object_ArrayValue: + ds.decodeArray(os, obj, x.ArrayValue) + case *pb.Object_StructValue: + ds.decodeStruct(os, obj, x.StructValue) + case *pb.Object_MapValue: + ds.decodeMap(os, obj, x.MapValue) + case *pb.Object_InterfaceValue: + ds.decodeInterface(os, obj, x.InterfaceValue) + case *pb.Object_ByteArrayValue: + copyArray(obj, reflect.ValueOf(x.ByteArrayValue)) + case *pb.Object_Uint16ArrayValue: + // 16-bit slices are serialized as 32-bit slices. + // See object.proto for details. + s := x.Uint16ArrayValue.Values + t := obj.Slice(0, obj.Len()).Interface().([]uint16) + if len(t) != len(s) { + panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) + } + for i := range s { + t[i] = uint16(s[i]) + } + case *pb.Object_Uint32ArrayValue: + copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values)) + case *pb.Object_Uint64ArrayValue: + copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values)) + case *pb.Object_UintptrArrayValue: + copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0)))) + case *pb.Object_Int8ArrayValue: + copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0)))) + case *pb.Object_Int16ArrayValue: + // 16-bit slices are serialized as 32-bit slices. + // See object.proto for details. + s := x.Int16ArrayValue.Values + t := obj.Slice(0, obj.Len()).Interface().([]int16) + if len(t) != len(s) { + panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) + } + for i := range s { + t[i] = int16(s[i]) + } + case *pb.Object_Int32ArrayValue: + copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values)) + case *pb.Object_Int64ArrayValue: + copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values)) + case *pb.Object_BoolArrayValue: + copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values)) + case *pb.Object_Float64ArrayValue: + copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values)) + case *pb.Object_Float32ArrayValue: + copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values)) + default: + // Shoud not happen, not propagated as an error. + panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type())) + } + + ds.stats.Done() + ds.pop() +} + +func copyArray(dest reflect.Value, src reflect.Value) { + if dest.Len() != src.Len() { + panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len())) + } + reflect.Copy(dest, castSlice(src, dest.Type().Elem())) +} + +// Deserialize deserializes the object state. +// +// This function may panic and should be run in safely(). +func (ds *decodeState) Deserialize(obj reflect.Value) { + ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()} + ds.outstanding = 1 // The root object. + + // Decode all objects in the stream. + // + // See above, we never process objects while we have no outstanding + // interests (other than the very first object). + for id := uint64(1); ds.outstanding > 0; id++ { + o, err := ds.readObject() + if err != nil { + panic(err) + } + + os := ds.lookup(id) + if os != nil { + // Decode the object. + ds.from = &os.path + ds.decodeObject(os, os.obj, o, "", nil) + ds.outstanding-- + } else { + // If an object hasn't had interest registered + // previously, we deferred decoding until interest is + // registered. + ds.deferred[id] = o + } + } + + // Check the zero-length header at the end. + length, object, err := ReadHeader(ds.r) + if err != nil { + panic(err) + } + if length != 0 { + panic(fmt.Sprintf("expected zero-length terminal, got %d", length)) + } + if object { + panic("expected non-object terminal") + } + + // Check if we have any deferred objects. + if count := len(ds.deferred); count > 0 { + // Shoud not happen, not propagated as an error. + panic(fmt.Sprintf("still have %d deferred objects", count)) + } + + // Scan and fire all callbacks. + for _, os := range ds.objectsByID { + os.checkComplete(ds.stats) + } + + // Check if we have any remaining dependency cycles. + for _, os := range ds.objectsByID { + if !os.complete() { + // This must be the result of a dependency cycle. + cycle := os.findCycle() + var buf bytes.Buffer + buf.WriteString("dependency cycle: {") + for i, cycleOS := range cycle { + if i > 0 { + buf.WriteString(" => ") + } + buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type())) + } + buf.WriteString("}") + // Panic as an error; propagate to the caller. + panic(errors.New(string(buf.Bytes()))) + } + } +} + +type byteReader struct { + io.Reader +} + +// ReadByte implements io.ByteReader. +func (br byteReader) ReadByte() (byte, error) { + var b [1]byte + n, err := br.Reader.Read(b[:]) + if n > 0 { + return b[0], nil + } else if err != nil { + return 0, err + } else { + return 0, io.ErrUnexpectedEOF + } +} + +// ReadHeader reads an object header. +// +// Each object written to the statefile is prefixed with a header. See +// WriteHeader for more information; these functions are exported to allow +// non-state writes to the file to play nice with debugging tools. +func ReadHeader(r io.Reader) (length uint64, object bool, err error) { + // Read the header. + length, err = binary.ReadUvarint(byteReader{r}) + if err != nil { + return + } + + // Decode whether the object is valid. + object = length&0x1 != 0 + length = length >> 1 + return +} + +// readObject reads an object from the stream. +func (ds *decodeState) readObject() (*pb.Object, error) { + // Read the header. + length, object, err := ReadHeader(ds.r) + if err != nil { + return nil, err + } + if !object { + return nil, fmt.Errorf("invalid object header") + } + + // Read the object. + buf := make([]byte, length) + for done := 0; done < len(buf); { + n, err := ds.r.Read(buf[done:]) + done += n + if n == 0 && err != nil { + return nil, err + } + } + + // Unmarshal. + obj := new(pb.Object) + if err := proto.Unmarshal(buf, obj); err != nil { + return nil, err + } + + return obj, nil +} |