summaryrefslogtreecommitdiffhomepage
path: root/pkg/state/encode.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/state/encode.go')
-rw-r--r--pkg/state/encode.go1025
1 files changed, 698 insertions, 327 deletions
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
index c5118d3a9..92fcad4e9 100644
--- a/pkg/state/encode.go
+++ b/pkg/state/encode.go
@@ -15,437 +15,797 @@
package state
import (
- "container/list"
"context"
- "encoding/binary"
- "fmt"
- "io"
"reflect"
- "sort"
- "github.com/golang/protobuf/proto"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
-// queuedObject is an object queued for encoding.
-type queuedObject struct {
- id uint64
- obj reflect.Value
- path recoverable
+// objectEncodeState the type and identity of an object occupying a memory
+// address range. This is the value type for addrSet, and the intrusive entry
+// for the pending and deferred lists.
+type objectEncodeState struct {
+ // id is the assigned ID for this object.
+ id objectID
+
+ // obj is the object value. Note that this may be replaced if we
+ // encounter an object that contains this object. When this happens (in
+ // resolve), we will update existing references approprately, below,
+ // and defer a re-encoding of the object.
+ obj reflect.Value
+
+ // encoded is the encoded value of this object. Note that this may not
+ // be up to date if this object is still in the deferred list.
+ encoded wire.Object
+
+ // how indicates whether this object should be encoded as a value. This
+ // is used only for deferred encoding.
+ how encodeStrategy
+
+ // refs are the list of reference objects used by other objects
+ // referring to this object. When the object is updated, these
+ // references may be updated directly and automatically.
+ refs []*wire.Ref
+
+ pendingEntry
+ deferredEntry
}
// encodeState is state used for encoding.
//
-// The encoding process is a breadth-first traversal of the object graph. The
-// inherent races and dependencies are much simpler than the decode case.
+// The encoding process constructs a representation of the in-memory graph of
+// objects before a single object is serialized. This is done to ensure that
+// all references can be fully disambiguated. See resolve for more details.
type encodeState struct {
// ctx is the encode context.
ctx context.Context
- // lastID is the last object ID.
- //
- // See idsByObject for context. Because of the special zero encoding
- // used for reference values, the first ID must be 1.
- lastID uint64
+ // w is the output stream.
+ w wire.Writer
- // idsByObject is a set of objects, indexed via:
- //
- // reflect.ValueOf(x).UnsafeAddr
- //
- // This provides IDs for objects.
- idsByObject map[uintptr]uint64
+ // types is the type database.
+ types typeEncodeDatabase
+
+ // lastID is the last allocated object ID.
+ lastID objectID
- // values stores values that span the addresses.
+ // values tracks the address ranges occupied by objects, along with the
+ // types of these objects. This is used to locate pointer targets,
+ // including pointers to fields within another type.
//
- // addrSet is a a generated type which efficiently stores ranges of
- // addresses. When encoding pointers, these ranges are filled in and
- // used to check for overlapping or conflicting pointers. This would
- // indicate a pointer to an field, or a non-type safe value, neither of
- // which are currently decodable.
+ // Multiple objects may overlap in memory iff the larger object fully
+ // contains the smaller one, and the type of the smaller object matches
+ // a field or array element's type at the appropriate offset. An
+ // arbitrary number of objects may be nested in this manner.
//
- // See the usage of values below for more context.
+ // Note that this does not track zero-sized objects, those are tracked
+ // by zeroValues below.
values addrSet
- // w is the output stream.
- w io.Writer
+ // zeroValues tracks zero-sized objects.
+ zeroValues map[reflect.Type]*objectEncodeState
- // pending is the list of objects to be serialized.
- //
- // This is a set of queuedObjects.
- pending list.List
+ // deferred is the list of objects to be encoded.
+ deferred deferredList
- // done is the a list of finished objects.
- //
- // This is kept to prevent garbage collection and address reuse.
- done list.List
+ // pendingTypes is the list of types to be serialized. Serialization
+ // will occur when all objects have been encoded, but before pending is
+ // serialized.
+ pendingTypes []wire.Type
- // stats is the passed stats object.
- stats *Stats
+ // pending is the list of objects to be serialized. Serialization does
+ // not actually occur until the full object graph is computed.
+ pending pendingList
- // recoverable is the panic recover facility.
- recoverable
+ // stats tracks time data.
+ stats Stats
}
-// register looks up an ID, registering if necessary.
+// isSameSizeParent returns true if child is a field value or element within
+// parent. Only a struct or array can have a child value.
+//
+// isSameSizeParent deals with objects like this:
+//
+// struct child {
+// // fields..
+// }
//
-// If the object was not previously registered, it is enqueued to be serialized.
-// See the documentation for idsByObject for more information.
-func (es *encodeState) register(obj reflect.Value) uint64 {
- // It is not legal to call register for any non-pointer objects (see
- // below), so we panic with a recoverable error if this is a mismatch.
- if obj.Kind() != reflect.Ptr && obj.Kind() != reflect.Map {
- panic(fmt.Errorf("non-pointer %#v registered", obj.Interface()))
+// struct parent {
+// c child
+// }
+//
+// var p parent
+// record(&p.c)
+//
+// Here, &p and &p.c occupy the exact same address range.
+//
+// Or like this:
+//
+// struct child {
+// // fields
+// }
+//
+// var arr [1]parent
+// record(&arr[0])
+//
+// Similarly, &arr[0] and &arr[0].c have the exact same address range.
+//
+// Precondition: parent and child must occupy the same memory.
+func isSameSizeParent(parent reflect.Value, childType reflect.Type) bool {
+ switch parent.Kind() {
+ case reflect.Struct:
+ for i := 0; i < parent.NumField(); i++ {
+ field := parent.Field(i)
+ if field.Type() == childType {
+ return true
+ }
+ // Recurse through any intermediate types.
+ if isSameSizeParent(field, childType) {
+ return true
+ }
+ // Does it make sense to keep going if the first field
+ // doesn't match? Yes, because there might be an
+ // arbitrary number of zero-sized fields before we get
+ // a match, and childType itself can be zero-sized.
+ }
+ return false
+ case reflect.Array:
+ // The only case where an array with more than one elements can
+ // return true is if childType is zero-sized. In such cases,
+ // it's ambiguous which element contains the match since a
+ // zero-sized child object fully fits in any of the zero-sized
+ // elements in an array... However since all elements are of
+ // the same type, we only need to check one element.
+ //
+ // For non-zero-sized childTypes, parent.Len() must be 1, but a
+ // combination of the precondition and an implicit comparison
+ // between the array element size and childType ensures this.
+ return parent.Len() > 0 && isSameSizeParent(parent.Index(0), childType)
+ default:
+ return false
}
+}
- addr := obj.Pointer()
- if obj.Kind() == reflect.Ptr && obj.Elem().Type().Size() == 0 {
- // For zero-sized objects, we always provide a unique ID.
- // That's because the runtime internally multiplexes pointers
- // to the same address. We can't be certain what the intent is
- // with pointers to zero-sized objects, so we just give them
- // all unique identities.
- } else if id, ok := es.idsByObject[addr]; ok {
- // Already registered.
- return id
- }
-
- // Ensure that the first ID given out is one. See note on lastID. The
- // ID zero is used to indicate nil values.
+// nextID returns the next valid ID.
+func (es *encodeState) nextID() objectID {
es.lastID++
- id := es.lastID
- es.idsByObject[addr] = id
- if obj.Kind() == reflect.Ptr {
- // Dereference and treat as a pointer.
- es.pending.PushBack(queuedObject{id: id, obj: obj.Elem(), path: es.recoverable.copy()})
-
- // Register this object at all addresses.
- typ := obj.Elem().Type()
- if size := typ.Size(); size > 0 {
- r := addrRange{addr, addr + size}
- if !es.values.IsEmptyRange(r) {
- old := es.values.LowerBoundSegment(addr).Value().Interface().(recoverable)
- panic(fmt.Errorf("overlapping objects: [new object] %#v [existing object path] %s", obj.Interface(), old.path()))
+ return objectID(es.lastID)
+}
+
+// dummyAddr points to the dummy zero-sized address.
+var dummyAddr = reflect.ValueOf(new(struct{})).Pointer()
+
+// resolve records the address range occupied by an object.
+func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
+ addr := obj.Pointer()
+
+ // Is this a map pointer? Just record the single address. It is not
+ // possible to take any pointers into the map internals.
+ if obj.Kind() == reflect.Map {
+ if addr == 0 {
+ // Just leave the nil reference alone. This is fine, we
+ // may need to encode as a reference in this way. We
+ // return nil for our objectEncodeState so that anyone
+ // depending on this value knows there's nothing there.
+ return
+ }
+ if seg, _ := es.values.Find(addr); seg.Ok() {
+ // Ensure the map types match.
+ existing := seg.Value()
+ if existing.obj.Type() != obj.Type() {
+ Failf("overlapping map objects at 0x%x: [new object] %#v [existing object type] %s", addr, obj, existing.obj)
}
- es.values.Add(r, reflect.ValueOf(es.recoverable.copy()))
+
+ // No sense recording refs, maps may not be replaced by
+ // covering objects, they are maximal.
+ ref.Root = wire.Uint(existing.id)
+ return
}
+
+ // Record the map.
+ oes := &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ how: encodeMapAsValue,
+ }
+ es.values.Add(addrRange{addr, addr + 1}, oes)
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+
+ // See above: no ref recording.
+ ref.Root = wire.Uint(oes.id)
+ return
+ }
+
+ // If not a map, then the object must be a pointer.
+ if obj.Kind() != reflect.Ptr {
+ Failf("attempt to record non-map and non-pointer object %#v", obj)
+ }
+
+ obj = obj.Elem() // Value from here.
+
+ // Is this a zero-sized type?
+ typ := obj.Type()
+ size := typ.Size()
+ if size == 0 {
+ if addr == dummyAddr {
+ // Zero-sized objects point to a dummy byte within the
+ // runtime. There's no sense recording this in the
+ // address map. We add this to the dedicated
+ // zeroValues.
+ //
+ // Note that zero-sized objects must be *true*
+ // zero-sized objects. They cannot be part of some
+ // larger object. In that case, they are assigned a
+ // 1-byte address at the end of the object.
+ oes, ok := es.zeroValues[typ]
+ if !ok {
+ oes = &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ }
+ es.zeroValues[typ] = oes
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+ }
+
+ // There's also no sense tracking back references. We
+ // know that this is a true zero-sized object, and not
+ // part of a larger container, so it will not change.
+ ref.Root = wire.Uint(oes.id)
+ return
+ }
+ size = 1 // See above.
+ }
+
+ // Calculate the container.
+ end := addr + size
+ r := addrRange{addr, end}
+ if seg, _ := es.values.Find(addr); seg.Ok() {
+ existing := seg.Value()
+ switch {
+ case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type():
+ // The object is a perfect match. Happy path. Avoid the
+ // traversal and just return directly. We don't need to
+ // encode the type information or any dots here.
+ ref.Root = wire.Uint(existing.id)
+ existing.refs = append(existing.refs, ref)
+ return
+
+ case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end):
+ // The previously registered object is larger than
+ // this, no need to update. But we expect some
+ // traversal below.
+
+ case seg.Start() == addr && seg.End() == end:
+ if !isSameSizeParent(obj, existing.obj.Type()) {
+ break // Needs traversal.
+ }
+ fallthrough // Needs update.
+
+ case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end):
+ // Update the object and redo the encoding.
+ old := existing.obj
+ existing.obj = obj
+ es.deferred.Remove(existing)
+ es.deferred.PushBack(existing)
+
+ // The previously registered object is superseded by
+ // this new object. We are guaranteed to not have any
+ // mergeable neighbours in this segment set.
+ if !raceEnabled {
+ seg.SetRangeUnchecked(r)
+ } else {
+ // Add extra paranoid. This will be statically
+ // removed at compile time unless a race build.
+ es.values.Remove(seg)
+ es.values.Add(r, existing)
+ seg = es.values.LowerBoundSegment(addr)
+ }
+
+ // Compute the traversal required & update references.
+ dots := traverse(obj.Type(), old.Type(), addr, seg.Start())
+ wt := es.findType(obj.Type())
+ for _, ref := range existing.refs {
+ ref.Dots = append(ref.Dots, dots...)
+ ref.Type = wt
+ }
+ default:
+ // There is a non-sensical overlap.
+ Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj)
+ }
+
+ // Compute the new reference, record and return it.
+ ref.Root = wire.Uint(existing.id)
+ ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr)
+ ref.Type = es.findType(obj.Type())
+ existing.refs = append(existing.refs, ref)
+ return
+ }
+
+ // The only remaining case is a pointer value that doesn't overlap with
+ // any registered addresses. Create a new entry for it, and start
+ // tracking the first reference we just created.
+ oes := &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ }
+ if !raceEnabled {
+ es.values.AddWithoutMerging(r, oes)
} else {
- // Push back the map itself; when maps are encoded from the
- // top-level, forceMap will be equal to true.
- es.pending.PushBack(queuedObject{id: id, obj: obj, path: es.recoverable.copy()})
+ // Merges should never happen. This is just enabled extra
+ // sanity checks because the Merge function below will panic.
+ es.values.Add(r, oes)
+ }
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+ ref.Root = wire.Uint(oes.id)
+ oes.refs = append(oes.refs, ref)
+}
+
+// traverse searches for a target object within a root object, where the target
+// object is a struct field or array element within root, with potentially
+// multiple intervening types. traverse returns the set of field or element
+// traversals required to reach the target.
+//
+// Note that for efficiency, traverse returns the dots in the reverse order.
+// That is, the first traversal required will be the last element of the list.
+//
+// Precondition: The target object must lie completely within the range defined
+// by [rootAddr, rootAddr + sizeof(rootType)].
+func traverse(rootType, targetType reflect.Type, rootAddr, targetAddr uintptr) []wire.Dot {
+ // Recursion base case: the types actually match.
+ if targetType == rootType && targetAddr == rootAddr {
+ return nil
}
- return id
+ switch rootType.Kind() {
+ case reflect.Struct:
+ offset := targetAddr - rootAddr
+ for i := rootType.NumField(); i > 0; i-- {
+ field := rootType.Field(i - 1)
+ // The first field from the end with an offset that is
+ // smaller than or equal to our address offset is where
+ // the target is located. Traverse from there.
+ if field.Offset <= offset {
+ dots := traverse(field.Type, targetType, rootAddr+field.Offset, targetAddr)
+ fieldName := wire.FieldName(field.Name)
+ return append(dots, &fieldName)
+ }
+ }
+ // Should never happen; the target should be reachable.
+ Failf("no field in root type %v contains target type %v", rootType, targetType)
+
+ case reflect.Array:
+ // Since arrays have homogenous types, all elements have the
+ // same size and we can compute where the target lives. This
+ // does not matter for the purpose of typing, but matters for
+ // the purpose of computing the address of the given index.
+ elemSize := int(rootType.Elem().Size())
+ n := int(targetAddr-rootAddr) / elemSize // Relies on integer division rounding down.
+ if rootType.Len() < n {
+ Failf("traversal target of type %v @%x is beyond the end of the array type %v @%x with %v elements",
+ targetType, targetAddr, rootType, rootAddr, rootType.Len())
+ }
+ dots := traverse(rootType.Elem(), targetType, rootAddr+uintptr(n*elemSize), targetAddr)
+ return append(dots, wire.Index(n))
+
+ default:
+ // For any other type, there's no possibility of aliasing so if
+ // the types didn't match earlier then we have an addresss
+ // collision which shouldn't be possible at this point.
+ Failf("traverse failed for root type %v and target type %v", rootType, targetType)
+ }
+ panic("unreachable")
}
// encodeMap encodes a map.
-func (es *encodeState) encodeMap(obj reflect.Value) *pb.Map {
- var (
- keys []*pb.Object
- values []*pb.Object
- )
+func (es *encodeState) encodeMap(obj reflect.Value, dest *wire.Object) {
+ if obj.IsNil() {
+ // Because there is a difference between a nil map and an empty
+ // map, we need to not decode in the case of a truly nil map.
+ *dest = wire.Nil{}
+ return
+ }
+ l := obj.Len()
+ m := &wire.Map{
+ Keys: make([]wire.Object, l),
+ Values: make([]wire.Object, l),
+ }
+ *dest = m
for i, k := range obj.MapKeys() {
v := obj.MapIndex(k)
- kp := es.encodeObject(k, false, ".(key %d)", i)
- vp := es.encodeObject(v, false, "[%#v]", k.Interface())
- keys = append(keys, kp)
- values = append(values, vp)
+ // Map keys must be encoded using the full value because the
+ // type will be omitted after the first key.
+ es.encodeObject(k, encodeAsValue, &m.Keys[i])
+ es.encodeObject(v, encodeAsValue, &m.Values[i])
}
- return &pb.Map{Keys: keys, Values: values}
+}
+
+// objectEncoder is for encoding structs.
+type objectEncoder struct {
+ // es is encodeState.
+ es *encodeState
+
+ // encoded is the encoded struct.
+ encoded *wire.Struct
+}
+
+// save is called by the public methods on Sink.
+func (oe *objectEncoder) save(slot int, obj reflect.Value) {
+ fieldValue := oe.encoded.Field(slot)
+ oe.es.encodeObject(obj, encodeDefault, fieldValue)
}
// encodeStruct encodes a composite object.
-func (es *encodeState) encodeStruct(obj reflect.Value) *pb.Struct {
- // Invoke the save.
- m := Map{newInternalMap(es, nil, nil)}
- defer internalMapPool.Put(m.internalMap)
+func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
+ // Ensure that the obj is addressable. There are two cases when it is
+ // not. First, is when this is dispatched via SaveValue. Second, when
+ // this is a map key as a struct. Either way, we need to make a copy to
+ // obtain an addressable value.
if !obj.CanAddr() {
- // Force it to a * type of the above; this involves a copy.
localObj := reflect.New(obj.Type())
localObj.Elem().Set(obj)
obj = localObj.Elem()
}
- fns, ok := registeredTypes.lookupFns(obj.Addr().Type())
- if ok {
- // Invoke the provided saver.
- fns.invokeSave(obj.Addr(), m)
- } else if obj.NumField() == 0 {
- // Allow unregistered anonymous, empty structs.
- return &pb.Struct{}
- } else {
- // Propagate an error.
- panic(fmt.Errorf("unregistered type %T", obj.Interface()))
- }
-
- // Sort the underlying slice, and check for duplicates. This is done
- // once instead of on each add, because performing this sort once is
- // far more efficient.
- if len(m.data) > 1 {
- sort.Slice(m.data, func(i, j int) bool {
- return m.data[i].name < m.data[j].name
- })
- for i := range m.data {
- if i > 0 && m.data[i-1].name == m.data[i].name {
- panic(fmt.Errorf("duplicate name %s", m.data[i].name))
- }
+
+ // Prepare the value.
+ s := &wire.Struct{}
+ *dest = s
+
+ // Look the type up in the database.
+ te, ok := es.types.Lookup(obj.Type())
+ if te == nil {
+ if obj.NumField() == 0 {
+ // Allow unregistered anonymous, empty structs. This
+ // will just return success without ever invoking the
+ // passed function. This uses the immutable EmptyStruct
+ // variable to prevent an allocation in this case.
+ //
+ // Note that this mechanism does *not* work for
+ // interfaces in general. So you can't dispatch
+ // non-registered empty structs via interfaces because
+ // then they can't be restored.
+ s.Alloc(0)
+ return
}
+ // We need a SaverLoader for struct types.
+ Failf("struct %T does not implement SaverLoader", obj.Interface())
}
-
- // Encode the resulting fields.
- fields := make([]*pb.Field, 0, len(m.data))
- for _, e := range m.data {
- fields = append(fields, &pb.Field{
- Name: e.name,
- Value: e.object,
- })
+ if !ok {
+ // Queue the type to be serialized.
+ es.pendingTypes = append(es.pendingTypes, te.Type)
}
- // Return the encoded object.
- return &pb.Struct{Fields: fields}
+ // Invoke the provided saver.
+ s.TypeID = wire.TypeID(te.ID)
+ s.Alloc(len(te.Fields))
+ oe := objectEncoder{
+ es: es,
+ encoded: s,
+ }
+ es.stats.start(te.ID)
+ defer es.stats.done()
+ if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
+ // Note: may be a registered empty struct which does not
+ // implement the saver/loader interfaces.
+ sl.StateSave(Sink{internal: oe})
+ }
}
// encodeArray encodes an array.
-func (es *encodeState) encodeArray(obj reflect.Value) *pb.Array {
- var (
- contents []*pb.Object
- )
- for i := 0; i < obj.Len(); i++ {
- entry := es.encodeObject(obj.Index(i), false, "[%d]", i)
- contents = append(contents, entry)
- }
- return &pb.Array{Contents: contents}
+func (es *encodeState) encodeArray(obj reflect.Value, dest *wire.Object) {
+ l := obj.Len()
+ a := &wire.Array{
+ Contents: make([]wire.Object, l),
+ }
+ *dest = a
+ for i := 0; i < l; i++ {
+ // We need to encode the full value because arrays are encoded
+ // using the type information from only the first element.
+ es.encodeObject(obj.Index(i), encodeAsValue, &a.Contents[i])
+ }
+}
+
+// findType recursively finds type information.
+func (es *encodeState) findType(typ reflect.Type) wire.TypeSpec {
+ // First: check if this is a proper type. It's possible for pointers,
+ // slices, arrays, maps, etc to all have some different type.
+ te, ok := es.types.Lookup(typ)
+ if te != nil {
+ if !ok {
+ // See encodeStruct.
+ es.pendingTypes = append(es.pendingTypes, te.Type)
+ }
+ return wire.TypeID(te.ID)
+ }
+
+ switch typ.Kind() {
+ case reflect.Ptr:
+ return &wire.TypeSpecPointer{
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Slice:
+ return &wire.TypeSpecSlice{
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Array:
+ return &wire.TypeSpecArray{
+ Count: wire.Uint(typ.Len()),
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Map:
+ return &wire.TypeSpecMap{
+ Key: es.findType(typ.Key()),
+ Value: es.findType(typ.Elem()),
+ }
+ default:
+ // After potentially chasing many pointers, the
+ // ultimate type of the object is not known.
+ Failf("type %q is not known", typ)
+ }
+ panic("unreachable")
}
// encodeInterface encodes an interface.
-//
-// Precondition: the value is not nil.
-func (es *encodeState) encodeInterface(obj reflect.Value) *pb.Interface {
- // Check for the nil interface.
- obj = reflect.ValueOf(obj.Interface())
+func (es *encodeState) encodeInterface(obj reflect.Value, dest *wire.Object) {
+ // Dereference the object.
+ obj = obj.Elem()
if !obj.IsValid() {
- return &pb.Interface{
- Type: "", // left alone in decode.
- Value: &pb.Object{Value: &pb.Object_RefValue{0}},
+ // Special case: the nil object.
+ *dest = &wire.Interface{
+ Type: wire.TypeSpecNil{},
+ Value: wire.Nil{},
}
+ return
}
- // We have an interface value here. How do we save that? We
- // resolve the underlying type and save it as a dispatchable.
- typName, ok := registeredTypes.lookupName(obj.Type())
- if !ok {
- panic(fmt.Errorf("type %s is not registered", obj.Type()))
+
+ // Encode underlying object.
+ i := &wire.Interface{
+ Type: es.findType(obj.Type()),
}
+ *dest = i
+ es.encodeObject(obj, encodeAsValue, &i.Value)
+}
- // Encode the object again.
- return &pb.Interface{
- Type: typName,
- Value: es.encodeObject(obj, false, ".(%s)", typName),
+// isPrimitive returns true if this is a primitive object, or a composite
+// object composed entirely of primitives.
+func isPrimitiveZero(typ reflect.Type) bool {
+ switch typ.Kind() {
+ case reflect.Ptr:
+ // Pointers are always treated as primitive types because we
+ // won't encode directly from here. Returning true here won't
+ // prevent the object from being encoded correctly.
+ return true
+ case reflect.Bool:
+ return true
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return true
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return true
+ case reflect.Float32, reflect.Float64:
+ return true
+ case reflect.Complex64, reflect.Complex128:
+ return true
+ case reflect.String:
+ return true
+ case reflect.Slice:
+ // The slice itself a primitive, but not necessarily the array
+ // that points to. This is similar to a pointer.
+ return true
+ case reflect.Array:
+ // We cannot treat an array as a primitive, because it may be
+ // composed of structures or other things with side-effects.
+ return isPrimitiveZero(typ.Elem())
+ case reflect.Interface:
+ // Since we now that this type is the zero type, the interface
+ // value must be zero. Therefore this is primitive.
+ return true
+ case reflect.Struct:
+ return false
+ case reflect.Map:
+ // The isPrimitiveZero function is called only on zero-types to
+ // see if it's safe to serialize. Since a zero map has no
+ // elements, it is safe to treat as a primitive.
+ return true
+ default:
+ Failf("unknown type %q", typ.Name())
}
+ panic("unreachable")
}
-// encodeObject encodes an object.
-//
-// If mapAsValue is true, then a map will be encoded directly.
-func (es *encodeState) encodeObject(obj reflect.Value, mapAsValue bool, format string, param interface{}) (object *pb.Object) {
- es.push(false, format, param)
- es.stats.Add(obj)
- es.stats.Start(obj)
+// encodeStrategy is the strategy used for encodeObject.
+type encodeStrategy int
+const (
+ // encodeDefault means types are encoded normally as references.
+ encodeDefault encodeStrategy = iota
+
+ // encodeAsValue means that types will never take short-circuited and
+ // will always be encoded as a normal value.
+ encodeAsValue
+
+ // encodeMapAsValue means that even maps will be fully encoded.
+ encodeMapAsValue
+)
+
+// encodeObject encodes an object.
+func (es *encodeState) encodeObject(obj reflect.Value, how encodeStrategy, dest *wire.Object) {
+ if how == encodeDefault && isPrimitiveZero(obj.Type()) && obj.IsZero() {
+ *dest = wire.Nil{}
+ return
+ }
switch obj.Kind() {
+ case reflect.Ptr: // Fast path: first.
+ r := new(wire.Ref)
+ *dest = r
+ if obj.IsNil() {
+ // May be in an array or elsewhere such that a value is
+ // required. So we encode as a reference to the zero
+ // object, which does not exist. Note that this has to
+ // be handled correctly in the decode path as well.
+ return
+ }
+ es.resolve(obj, r)
case reflect.Bool:
- object = &pb.Object{Value: &pb.Object_BoolValue{obj.Bool()}}
+ *dest = wire.Bool(obj.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- object = &pb.Object{Value: &pb.Object_Int64Value{obj.Int()}}
+ *dest = wire.Int(obj.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- object = &pb.Object{Value: &pb.Object_Uint64Value{obj.Uint()}}
- case reflect.Float32, reflect.Float64:
- object = &pb.Object{Value: &pb.Object_DoubleValue{obj.Float()}}
+ *dest = wire.Uint(obj.Uint())
+ case reflect.Float32:
+ *dest = wire.Float32(obj.Float())
+ case reflect.Float64:
+ *dest = wire.Float64(obj.Float())
+ case reflect.Complex64:
+ c := wire.Complex64(obj.Complex())
+ *dest = &c // Needs alloc.
+ case reflect.Complex128:
+ c := wire.Complex128(obj.Complex())
+ *dest = &c // Needs alloc.
+ case reflect.String:
+ s := wire.String(obj.String())
+ *dest = &s // Needs alloc.
case reflect.Array:
- switch obj.Type().Elem().Kind() {
- case reflect.Uint8:
- object = &pb.Object{Value: &pb.Object_ByteArrayValue{pbSlice(obj).Interface().([]byte)}}
- case reflect.Uint16:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := pbSlice(obj).Interface().([]uint16)
- t := make([]uint32, len(s))
- for i := range s {
- t[i] = uint32(s[i])
- }
- object = &pb.Object{Value: &pb.Object_Uint16ArrayValue{&pb.Uint16S{Values: t}}}
- case reflect.Uint32:
- object = &pb.Object{Value: &pb.Object_Uint32ArrayValue{&pb.Uint32S{Values: pbSlice(obj).Interface().([]uint32)}}}
- case reflect.Uint64:
- object = &pb.Object{Value: &pb.Object_Uint64ArrayValue{&pb.Uint64S{Values: pbSlice(obj).Interface().([]uint64)}}}
- case reflect.Uintptr:
- object = &pb.Object{Value: &pb.Object_UintptrArrayValue{&pb.Uintptrs{Values: pbSlice(obj).Interface().([]uint64)}}}
- case reflect.Int8:
- object = &pb.Object{Value: &pb.Object_Int8ArrayValue{&pb.Int8S{Values: pbSlice(obj).Interface().([]byte)}}}
- case reflect.Int16:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := pbSlice(obj).Interface().([]int16)
- t := make([]int32, len(s))
- for i := range s {
- t[i] = int32(s[i])
- }
- object = &pb.Object{Value: &pb.Object_Int16ArrayValue{&pb.Int16S{Values: t}}}
- case reflect.Int32:
- object = &pb.Object{Value: &pb.Object_Int32ArrayValue{&pb.Int32S{Values: pbSlice(obj).Interface().([]int32)}}}
- case reflect.Int64:
- object = &pb.Object{Value: &pb.Object_Int64ArrayValue{&pb.Int64S{Values: pbSlice(obj).Interface().([]int64)}}}
- case reflect.Bool:
- object = &pb.Object{Value: &pb.Object_BoolArrayValue{&pb.Bools{Values: pbSlice(obj).Interface().([]bool)}}}
- case reflect.Float32:
- object = &pb.Object{Value: &pb.Object_Float32ArrayValue{&pb.Float32S{Values: pbSlice(obj).Interface().([]float32)}}}
- case reflect.Float64:
- object = &pb.Object{Value: &pb.Object_Float64ArrayValue{&pb.Float64S{Values: pbSlice(obj).Interface().([]float64)}}}
- default:
- object = &pb.Object{Value: &pb.Object_ArrayValue{es.encodeArray(obj)}}
- }
+ es.encodeArray(obj, dest)
case reflect.Slice:
- if obj.IsNil() || obj.Cap() == 0 {
- // Handled specially in decode; store as nil value.
- object = &pb.Object{Value: &pb.Object_RefValue{0}}
- } else {
- // Serialize a slice as the array plus length and capacity.
- object = &pb.Object{Value: &pb.Object_SliceValue{&pb.Slice{
- Capacity: uint32(obj.Cap()),
- Length: uint32(obj.Len()),
- RefValue: es.register(arrayFromSlice(obj)),
- }}}
+ s := &wire.Slice{
+ Capacity: wire.Uint(obj.Cap()),
+ Length: wire.Uint(obj.Len()),
}
- case reflect.String:
- object = &pb.Object{Value: &pb.Object_StringValue{[]byte(obj.String())}}
- case reflect.Ptr:
+ *dest = s
+ // Note that we do need to provide a wire.Slice type here as
+ // how is not encodeDefault. If this were the case, then it
+ // would have been caught by the IsZero check above and we
+ // would have just used wire.Nil{}.
if obj.IsNil() {
- // Handled specially in decode; store as a nil value.
- object = &pb.Object{Value: &pb.Object_RefValue{0}}
- } else {
- es.push(true /* dereference */, "", nil)
- object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
- es.pop()
+ return
}
+ // Slices need pointer resolution.
+ es.resolve(arrayFromSlice(obj), &s.Ref)
case reflect.Interface:
- // We don't check for IsNil here, as we want to encode type
- // information. The case of the empty interface (no type, no
- // value) is handled by encodeInteface.
- object = &pb.Object{Value: &pb.Object_InterfaceValue{es.encodeInterface(obj)}}
+ es.encodeInterface(obj, dest)
case reflect.Struct:
- object = &pb.Object{Value: &pb.Object_StructValue{es.encodeStruct(obj)}}
+ es.encodeStruct(obj, dest)
case reflect.Map:
- if obj.IsNil() {
- // Handled specially in decode; store as a nil value.
- object = &pb.Object{Value: &pb.Object_RefValue{0}}
- } else if mapAsValue {
- // Encode the map directly.
- object = &pb.Object{Value: &pb.Object_MapValue{es.encodeMap(obj)}}
- } else {
- // Encode a reference to the map.
- //
- // Remove the map object count here to avoid double
- // counting, as this object will be counted again when
- // it gets processed later. We do not add a reference
- // count as the reference is artificial.
- es.stats.Remove(obj)
- object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
+ if how == encodeMapAsValue {
+ es.encodeMap(obj, dest)
+ return
}
+ r := new(wire.Ref)
+ *dest = r
+ es.resolve(obj, r)
default:
- panic(fmt.Errorf("unknown primitive %#v", obj.Interface()))
+ Failf("unknown object %#v", obj.Interface())
+ panic("unreachable")
}
-
- es.stats.Done()
- es.pop()
- return
}
-// Serialize serializes the object state.
-//
-// This function may panic and should be run in safely().
-func (es *encodeState) Serialize(obj reflect.Value) {
- es.register(obj.Addr())
-
- // Pop off the list until we're done.
- for es.pending.Len() > 0 {
- e := es.pending.Front()
-
- // Extract the queued object.
- qo := e.Value.(queuedObject)
- es.stats.Start(qo.obj)
+// Save serializes the object graph rooted at obj.
+func (es *encodeState) Save(obj reflect.Value) {
+ es.stats.init()
+ defer es.stats.fini(func(id typeID) string {
+ return es.pendingTypes[id-1].Name
+ })
+
+ // Resolve the first object, which should queue a pile of additional
+ // objects on the pending list. All queued objects should be fully
+ // resolved, and we should be able to serialize after this call.
+ var root wire.Ref
+ es.resolve(obj.Addr(), &root)
+
+ // Encode the graph.
+ var oes *objectEncodeState
+ if err := safely(func() {
+ for oes = es.deferred.Front(); oes != nil; oes = es.deferred.Front() {
+ // Remove and encode the object. Note that as a result
+ // of this encoding, the object may be enqueued on the
+ // deferred list yet again. That's expected, and why it
+ // is removed first.
+ es.deferred.Remove(oes)
+ es.encodeObject(oes.obj, oes.how, &oes.encoded)
+ }
+ }); err != nil {
+ // Include the object in the error message.
+ Failf("encoding error at object %#v: %w", oes.obj.Interface(), err)
+ }
- es.pending.Remove(e)
+ // Check that items are pending.
+ if es.pending.Front() == nil {
+ Failf("pending is empty?")
+ }
- es.from = &qo.path
- o := es.encodeObject(qo.obj, true, "", nil)
+ // Write the header with the number of objects. Note that there is no
+ // way that es.lastID could conflict with objectID, which would
+ // indicate that an impossibly large encoding.
+ if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil {
+ Failf("error writing header: %w", err)
+ }
- // Emit to our output stream.
- if err := es.writeObject(qo.id, o); err != nil {
- panic(err)
+ // Serialize all pending types and pending objects. Note that we don't
+ // bother removing from this list as we walk it because that just
+ // wastes time. It will not change after this point.
+ var id objectID
+ if err := safely(func() {
+ for _, wt := range es.pendingTypes {
+ // Encode the type.
+ wire.Save(es.w, &wt)
}
+ for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() {
+ id++ // First object is 1.
+ if oes.id != id {
+ Failf("expected id %d, got %d", id, oes.id)
+ }
- // Mark as done.
- es.done.PushBack(e)
- es.stats.Done()
+ // Marshall the object.
+ wire.Save(es.w, oes.encoded)
+ }
+ }); err != nil {
+ // Include the object and the error.
+ Failf("error serializing object %#v: %w", oes.encoded, err)
}
- // Write a zero-length terminal at the end; this is a sanity check
- // applied at decode time as well (see decode.go).
- if err := WriteHeader(es.w, 0, false); err != nil {
- panic(err)
+ // Check what we wrote.
+ if id != es.lastID {
+ Failf("expected %d objects, wrote %d", es.lastID, id)
}
}
+// objectFlag indicates that the length is a # of objects, rather than a raw
+// byte length. When this is set on a length header in the stream, it may be
+// decoded appropriately.
+const objectFlag uint64 = 1 << 63
+
// WriteHeader writes a header.
//
// Each object written to the statefile should be prefixed with a header. In
// order to generate statefiles that play nicely with debugging tools, raw
// writes should be prefixed with a header with object set to false and the
// appropriate length. This will allow tools to skip these regions.
-func WriteHeader(w io.Writer, length uint64, object bool) error {
- // The lowest-order bit encodes whether this is a valid object. This is
- // a purely internal convention, but allows the object flag to be
- // returned from ReadHeader.
- length = length << 1
+func WriteHeader(w wire.Writer, length uint64, object bool) error {
+ // Sanity check the length.
+ if length&objectFlag != 0 {
+ Failf("impossibly huge length: %d", length)
+ }
if object {
- length |= 0x1
+ length |= objectFlag
}
// Write a header.
- var hdr [32]byte
- encodedLen := binary.PutUvarint(hdr[:], length)
- for done := 0; done < encodedLen; {
- n, err := w.Write(hdr[done:encodedLen])
- done += n
- if n == 0 && err != nil {
- return err
- }
- }
-
- return nil
+ return safely(func() {
+ wire.SaveUint(w, length)
+ })
}
-// writeObject writes an object to the stream.
-func (es *encodeState) writeObject(id uint64, obj *pb.Object) error {
- // Marshal the proto.
- buf, err := proto.Marshal(obj)
- if err != nil {
- return err
- }
+// pendingMapper is for the pending list.
+type pendingMapper struct{}
- // Write the object header.
- if err := WriteHeader(es.w, uint64(len(buf)), true); err != nil {
- return err
- }
+func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry }
- // Write the object.
- for done := 0; done < len(buf); {
- n, err := es.w.Write(buf[done:])
- done += n
- if n == 0 && err != nil {
- return err
- }
- }
+// deferredMapper is for the deferred list.
+type deferredMapper struct{}
- return nil
-}
+func (deferredMapper) linkerFor(oes *objectEncodeState) *deferredEntry { return &oes.deferredEntry }
// addrSetFunctions is used by addrSet.
type addrSetFunctions struct{}
@@ -458,13 +818,24 @@ func (addrSetFunctions) MaxKey() uintptr {
return ^uintptr(0)
}
-func (addrSetFunctions) ClearValue(val *reflect.Value) {
+func (addrSetFunctions) ClearValue(val **objectEncodeState) {
+ *val = nil
}
-func (addrSetFunctions) Merge(_ addrRange, val1 reflect.Value, _ addrRange, val2 reflect.Value) (reflect.Value, bool) {
- return val1, val1 == val2
+func (addrSetFunctions) Merge(r1 addrRange, val1 *objectEncodeState, r2 addrRange, val2 *objectEncodeState) (*objectEncodeState, bool) {
+ if val1.obj == val2.obj {
+ // This, should never happen. It would indicate that the same
+ // object exists in two non-contiguous address ranges. Note
+ // that this assertion can only be triggered if the race
+ // detector is enabled.
+ Failf("unexpected merge in addrSet @ %v and %v: %#v and %#v", r1, r2, val1.obj, val2.obj)
+ }
+ // Reject the merge.
+ return val1, false
}
-func (addrSetFunctions) Split(_ addrRange, val reflect.Value, _ uintptr) (reflect.Value, reflect.Value) {
- return val, val
+func (addrSetFunctions) Split(r addrRange, val *objectEncodeState, _ uintptr) (*objectEncodeState, *objectEncodeState) {
+ // A split should never happen: we don't remove ranges.
+ Failf("unexpected split in addrSet @ %v: %#v", r, val.obj)
+ panic("unreachable")
}