summaryrefslogtreecommitdiffhomepage
path: root/pkg/state/decode.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/state/decode.go')
-rw-r--r--pkg/state/decode.go594
1 files changed, 594 insertions, 0 deletions
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
new file mode 100644
index 000000000..05758495b
--- /dev/null
+++ b/pkg/state/decode.go
@@ -0,0 +1,594 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "reflect"
+ "sort"
+
+ "github.com/golang/protobuf/proto"
+ pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto"
+)
+
+// objectState represents an object that may be in the process of being
+// decoded. Specifically, it represents either a decoded object, or an an
+// interest in a future object that will be decoded. When that interest is
+// registered (via register), the storage for the object will be created, but
+// it will not be decoded until the object is encountered in the stream.
+type objectState struct {
+ // id is the id for this object.
+ //
+ // If this field is zero, then this is an anonymous (unregistered,
+ // non-reference primitive) object. This is immutable.
+ id uint64
+
+ // obj is the object. This may or may not be valid yet, depending on
+ // whether complete returns true. However, regardless of whether the
+ // object is valid, obj contains a final storage location for the
+ // object. This is immutable.
+ //
+ // Note that this must be addressable (obj.Addr() must not panic).
+ //
+ // The obj passed to the decode methods below will equal this obj only
+ // in the case of decoding the top-level object. However, the passed
+ // obj may represent individual fields, elements of a slice, etc. that
+ // are effectively embedded within the reflect.Value below but with
+ // distinct types.
+ obj reflect.Value
+
+ // blockedBy is the number of dependencies this object has.
+ blockedBy int
+
+ // blocking is a list of the objects blocked by this one.
+ blocking []*objectState
+
+ // callbacks is a set of callbacks to execute on load.
+ callbacks []func()
+
+ // path is the decoding path to the object.
+ path recoverable
+}
+
+// complete indicates the object is complete.
+func (os *objectState) complete() bool {
+ return os.blockedBy == 0 && len(os.callbacks) == 0
+}
+
+// checkComplete checks for completion. If the object is complete, pending
+// callbacks will be executed and checkComplete will be called on downstream
+// objects (those depending on this one).
+func (os *objectState) checkComplete(stats *Stats) {
+ if os.blockedBy > 0 {
+ return
+ }
+
+ // Fire all callbacks.
+ for _, fn := range os.callbacks {
+ stats.Start(os.obj)
+ fn()
+ stats.Done()
+ }
+ os.callbacks = nil
+
+ // Clear all blocked objects.
+ for _, other := range os.blocking {
+ other.blockedBy--
+ other.checkComplete(stats)
+ }
+ os.blocking = nil
+}
+
+// waitFor queues a dependency on the given object.
+func (os *objectState) waitFor(other *objectState, callback func()) {
+ os.blockedBy++
+ other.blocking = append(other.blocking, os)
+ if callback != nil {
+ other.callbacks = append(other.callbacks, callback)
+ }
+}
+
+// findCycleFor returns when the given object is found in the blocking set.
+func (os *objectState) findCycleFor(target *objectState) []*objectState {
+ for _, other := range os.blocking {
+ if other == target {
+ return []*objectState{target}
+ } else if childList := other.findCycleFor(target); childList != nil {
+ return append(childList, other)
+ }
+ }
+ return nil
+}
+
+// findCycle finds a dependency cycle.
+func (os *objectState) findCycle() []*objectState {
+ return append(os.findCycleFor(os), os)
+}
+
+// decodeState is a graph of objects in the process of being decoded.
+//
+// The decode process involves loading the breadth-first graph generated by
+// encode. This graph is read in it's entirety, ensuring that all object
+// storage is complete.
+//
+// As the graph is being serialized, a set of completion callbacks are
+// executed. These completion callbacks should form a set of acyclic subgraphs
+// over the original one. After decoding is complete, the objects are scanned
+// to ensure that all callbacks are executed, otherwise the callback graph was
+// not acyclic.
+type decodeState struct {
+ // objectByID is the set of objects in progress.
+ objectsByID map[uint64]*objectState
+
+ // deferred are objects that have been read, by no interest has been
+ // registered yet. These will be decoded once interest in registered.
+ deferred map[uint64]*pb.Object
+
+ // outstanding is the number of outstanding objects.
+ outstanding uint32
+
+ // r is the input stream.
+ r io.Reader
+
+ // stats is the passed stats object.
+ stats *Stats
+
+ // recoverable is the panic recover facility.
+ recoverable
+}
+
+// lookup looks up an object in decodeState or returns nil if no such object
+// has been previously registered.
+func (ds *decodeState) lookup(id uint64) *objectState {
+ return ds.objectsByID[id]
+}
+
+// wait registers a dependency on an object.
+//
+// As a special case, we always allow _useable_ references back to the first
+// decoding object because it may have fields that are already decoded. We also
+// allow trivial self reference, since they can be handled internally.
+func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
+ switch id {
+ case 0:
+ // Nil pointer; nothing to wait for.
+ fallthrough
+ case waiter.id:
+ // Trivial self reference.
+ fallthrough
+ case 1:
+ // Root object; see above.
+ if callback != nil {
+ callback()
+ }
+ return
+ }
+
+ // No nil can be returned here.
+ waiter.waitFor(ds.lookup(id), callback)
+}
+
+// waitObject notes a blocking relationship.
+func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) {
+ if rv, ok := p.Value.(*pb.Object_RefValue); ok {
+ // Refs can encode pointers and maps.
+ ds.wait(os, rv.RefValue, callback)
+ } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok {
+ // See decodeObject; we need to wait for the array (if non-nil).
+ ds.wait(os, sv.SliceValue.RefValue, callback)
+ } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok {
+ // It's an interface (wait recurisvely).
+ ds.waitObject(os, iv.InterfaceValue.Value, callback)
+ } else if callback != nil {
+ // Nothing to wait for: execute the callback immediately.
+ callback()
+ }
+}
+
+// register registers a decode with a type.
+//
+// This type is only used to instantiate a new object if it has not been
+// registered previously.
+func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState {
+ os, ok := ds.objectsByID[id]
+ if ok {
+ return os
+ }
+
+ // Record in the object index.
+ if typ.Kind() == reflect.Map {
+ os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()}
+ } else {
+ os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()}
+ }
+ ds.objectsByID[id] = os
+
+ if o, ok := ds.deferred[id]; ok {
+ // There is a deferred object.
+ delete(ds.deferred, id) // Free memory.
+ ds.decodeObject(os, os.obj, o, "", nil)
+ } else {
+ // There is no deferred object.
+ ds.outstanding++
+ }
+
+ return os
+}
+
+// decodeStruct decodes a struct value.
+func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) {
+ // Set the fields.
+ m := Map{newInternalMap(nil, ds, os)}
+ defer internalMapPool.Put(m.internalMap)
+ for _, field := range s.Fields {
+ m.data = append(m.data, entry{
+ name: field.Name,
+ object: field.Value,
+ })
+ }
+
+ // Sort the fields for efficient searching.
+ //
+ // Technically, these should already appear in sorted order in the
+ // state ordering, so this cost is effectively a single scan to ensure
+ // that the order is correct.
+ if len(m.data) > 1 {
+ sort.Slice(m.data, func(i, j int) bool {
+ return m.data[i].name < m.data[j].name
+ })
+ }
+
+ // Invoke the load; this will recursively decode other objects.
+ fns, ok := registeredTypes.lookupFns(obj.Addr().Type())
+ if ok {
+ // Invoke the loader.
+ fns.invokeLoad(obj.Addr(), m)
+ } else if obj.NumField() == 0 {
+ // Allow anonymous empty structs.
+ return
+ } else {
+ // Propagate an error.
+ panic(fmt.Errorf("unregistered type %s", obj.Type()))
+ }
+}
+
+// decodeMap decodes a map value.
+func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) {
+ if obj.IsNil() {
+ obj.Set(reflect.MakeMap(obj.Type()))
+ }
+ for i := 0; i < len(m.Keys); i++ {
+ // Decode the objects.
+ kv := reflect.New(obj.Type().Key()).Elem()
+ vv := reflect.New(obj.Type().Elem()).Elem()
+ ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i)
+ ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface())
+ ds.waitObject(os, m.Keys[i], nil)
+ ds.waitObject(os, m.Values[i], nil)
+
+ // Set in the map.
+ obj.SetMapIndex(kv, vv)
+ }
+}
+
+// decodeArray decodes an array value.
+func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) {
+ if len(a.Contents) != obj.Len() {
+ panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents)))
+ }
+ // Decode the contents into the array.
+ for i := 0; i < len(a.Contents); i++ {
+ ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i)
+ ds.waitObject(os, a.Contents[i], nil)
+ }
+}
+
+// decodeInterface decodes an interface value.
+func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) {
+ // Is this a nil value?
+ if i.Type == "" {
+ return // Just leave obj alone.
+ }
+
+ // Get the dispatchable type. This may not be used if the given
+ // reference has already been resolved, but if not we need to know the
+ // type to create.
+ t, ok := registeredTypes.lookupType(i.Type)
+ if !ok {
+ panic(fmt.Errorf("no valid type for %q", i.Type))
+ }
+
+ if obj.Kind() != reflect.Map {
+ // Set the obj to be the given typed value; this actually sets
+ // obj to be a non-zero value -- namely, it inserts type
+ // information. There's no need to do this for maps.
+ obj.Set(reflect.Zero(t))
+ }
+
+ // Decode the dereferenced element; there is no need to wait here, as
+ // the interface object shares the current object state.
+ ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type)
+}
+
+// decodeObject decodes a object value.
+func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) {
+ ds.push(false, format, param)
+ ds.stats.Start(obj)
+
+ switch x := object.GetValue().(type) {
+ case *pb.Object_BoolValue:
+ obj.SetBool(x.BoolValue)
+ case *pb.Object_StringValue:
+ obj.SetString(x.StringValue)
+ case *pb.Object_Int64Value:
+ obj.SetInt(x.Int64Value)
+ if obj.Int() != x.Int64Value {
+ panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type()))
+ }
+ case *pb.Object_Uint64Value:
+ obj.SetUint(x.Uint64Value)
+ if obj.Uint() != x.Uint64Value {
+ panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type()))
+ }
+ case *pb.Object_DoubleValue:
+ obj.SetFloat(x.DoubleValue)
+ if obj.Float() != x.DoubleValue {
+ panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type()))
+ }
+ case *pb.Object_RefValue:
+ // Resolve the pointer itself, even though the object may not
+ // be decoded yet. You need to use wait() in order to ensure
+ // that is the case. See wait above, and Map.Barrier.
+ if id := x.RefValue; id != 0 {
+ // Decoding the interface should have imparted type
+ // information, so from this point it's safe to resolve
+ // and use this dynamic information for actually
+ // creating the object in register.
+ //
+ // (For non-interfaces this is a no-op).
+ dyntyp := reflect.TypeOf(obj.Interface())
+ if dyntyp.Kind() == reflect.Map {
+ obj.Set(ds.register(id, dyntyp).obj)
+ } else if dyntyp.Kind() == reflect.Ptr {
+ ds.push(true /* dereference */, "", nil)
+ obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
+ ds.pop()
+ } else {
+ obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
+ }
+ } else {
+ // We leave obj alone here. That's because if obj
+ // represents an interface, it may have been embued
+ // with type information in decodeInterface, and we
+ // don't want to destroy that information.
+ }
+ case *pb.Object_SliceValue:
+ // It's okay to slice the array here, since the contents will
+ // still be provided later on. These semantics are a bit
+ // strange but they are handled in the Map.Barrier properly.
+ //
+ // The special semantics of zero ref apply here too.
+ if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 {
+ v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem())
+ obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity)))
+ }
+ case *pb.Object_ArrayValue:
+ ds.decodeArray(os, obj, x.ArrayValue)
+ case *pb.Object_StructValue:
+ ds.decodeStruct(os, obj, x.StructValue)
+ case *pb.Object_MapValue:
+ ds.decodeMap(os, obj, x.MapValue)
+ case *pb.Object_InterfaceValue:
+ ds.decodeInterface(os, obj, x.InterfaceValue)
+ case *pb.Object_ByteArrayValue:
+ copyArray(obj, reflect.ValueOf(x.ByteArrayValue))
+ case *pb.Object_Uint16ArrayValue:
+ // 16-bit slices are serialized as 32-bit slices.
+ // See object.proto for details.
+ s := x.Uint16ArrayValue.Values
+ t := obj.Slice(0, obj.Len()).Interface().([]uint16)
+ if len(t) != len(s) {
+ panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ }
+ for i := range s {
+ t[i] = uint16(s[i])
+ }
+ case *pb.Object_Uint32ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values))
+ case *pb.Object_Uint64ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values))
+ case *pb.Object_UintptrArrayValue:
+ copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
+ case *pb.Object_Int8ArrayValue:
+ copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
+ case *pb.Object_Int16ArrayValue:
+ // 16-bit slices are serialized as 32-bit slices.
+ // See object.proto for details.
+ s := x.Int16ArrayValue.Values
+ t := obj.Slice(0, obj.Len()).Interface().([]int16)
+ if len(t) != len(s) {
+ panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ }
+ for i := range s {
+ t[i] = int16(s[i])
+ }
+ case *pb.Object_Int32ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values))
+ case *pb.Object_Int64ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values))
+ case *pb.Object_BoolArrayValue:
+ copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values))
+ case *pb.Object_Float64ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values))
+ case *pb.Object_Float32ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values))
+ default:
+ // Shoud not happen, not propagated as an error.
+ panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type()))
+ }
+
+ ds.stats.Done()
+ ds.pop()
+}
+
+func copyArray(dest reflect.Value, src reflect.Value) {
+ if dest.Len() != src.Len() {
+ panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len()))
+ }
+ reflect.Copy(dest, castSlice(src, dest.Type().Elem()))
+}
+
+// Deserialize deserializes the object state.
+//
+// This function may panic and should be run in safely().
+func (ds *decodeState) Deserialize(obj reflect.Value) {
+ ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()}
+ ds.outstanding = 1 // The root object.
+
+ // Decode all objects in the stream.
+ //
+ // See above, we never process objects while we have no outstanding
+ // interests (other than the very first object).
+ for id := uint64(1); ds.outstanding > 0; id++ {
+ o, err := ds.readObject()
+ if err != nil {
+ panic(err)
+ }
+
+ os := ds.lookup(id)
+ if os != nil {
+ // Decode the object.
+ ds.from = &os.path
+ ds.decodeObject(os, os.obj, o, "", nil)
+ ds.outstanding--
+ } else {
+ // If an object hasn't had interest registered
+ // previously, we deferred decoding until interest is
+ // registered.
+ ds.deferred[id] = o
+ }
+ }
+
+ // Check the zero-length header at the end.
+ length, object, err := ReadHeader(ds.r)
+ if err != nil {
+ panic(err)
+ }
+ if length != 0 {
+ panic(fmt.Sprintf("expected zero-length terminal, got %d", length))
+ }
+ if object {
+ panic("expected non-object terminal")
+ }
+
+ // Check if we have any deferred objects.
+ if count := len(ds.deferred); count > 0 {
+ // Shoud not happen, not propagated as an error.
+ panic(fmt.Sprintf("still have %d deferred objects", count))
+ }
+
+ // Scan and fire all callbacks.
+ for _, os := range ds.objectsByID {
+ os.checkComplete(ds.stats)
+ }
+
+ // Check if we have any remaining dependency cycles.
+ for _, os := range ds.objectsByID {
+ if !os.complete() {
+ // This must be the result of a dependency cycle.
+ cycle := os.findCycle()
+ var buf bytes.Buffer
+ buf.WriteString("dependency cycle: {")
+ for i, cycleOS := range cycle {
+ if i > 0 {
+ buf.WriteString(" => ")
+ }
+ buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type()))
+ }
+ buf.WriteString("}")
+ // Panic as an error; propagate to the caller.
+ panic(errors.New(string(buf.Bytes())))
+ }
+ }
+}
+
+type byteReader struct {
+ io.Reader
+}
+
+// ReadByte implements io.ByteReader.
+func (br byteReader) ReadByte() (byte, error) {
+ var b [1]byte
+ n, err := br.Reader.Read(b[:])
+ if n > 0 {
+ return b[0], nil
+ } else if err != nil {
+ return 0, err
+ } else {
+ return 0, io.ErrUnexpectedEOF
+ }
+}
+
+// ReadHeader reads an object header.
+//
+// Each object written to the statefile is prefixed with a header. See
+// WriteHeader for more information; these functions are exported to allow
+// non-state writes to the file to play nice with debugging tools.
+func ReadHeader(r io.Reader) (length uint64, object bool, err error) {
+ // Read the header.
+ length, err = binary.ReadUvarint(byteReader{r})
+ if err != nil {
+ return
+ }
+
+ // Decode whether the object is valid.
+ object = length&0x1 != 0
+ length = length >> 1
+ return
+}
+
+// readObject reads an object from the stream.
+func (ds *decodeState) readObject() (*pb.Object, error) {
+ // Read the header.
+ length, object, err := ReadHeader(ds.r)
+ if err != nil {
+ return nil, err
+ }
+ if !object {
+ return nil, fmt.Errorf("invalid object header")
+ }
+
+ // Read the object.
+ buf := make([]byte, length)
+ for done := 0; done < len(buf); {
+ n, err := ds.r.Read(buf[done:])
+ done += n
+ if n == 0 && err != nil {
+ return nil, err
+ }
+ }
+
+ // Unmarshal.
+ obj := new(pb.Object)
+ if err := proto.Unmarshal(buf, obj); err != nil {
+ return nil, err
+ }
+
+ return obj, nil
+}