summaryrefslogtreecommitdiffhomepage
path: root/pkg/state
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/state')
-rw-r--r--pkg/state/BUILD14
-rw-r--r--pkg/state/decode.go86
-rw-r--r--pkg/state/decode_unsafe.go57
-rw-r--r--pkg/state/encode.go221
-rw-r--r--pkg/state/pretty/pretty.go121
-rw-r--r--pkg/state/state.go10
-rw-r--r--pkg/state/tests/load_test.go8
-rw-r--r--pkg/state/tests/struct.go35
-rw-r--r--pkg/state/tests/struct_test.go34
-rw-r--r--pkg/state/types.go14
10 files changed, 399 insertions, 201 deletions
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index 089b3bbef..92c51879b 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -4,19 +4,6 @@ load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
go_template_instance(
- name = "pending_list",
- out = "pending_list.go",
- package = "state",
- prefix = "pending",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*objectEncodeState",
- "ElementMapper": "pendingMapper",
- "Linker": "*pendingEntry",
- },
-)
-
-go_template_instance(
name = "deferred_list",
out = "deferred_list.go",
package = "state",
@@ -83,7 +70,6 @@ go_library(
"deferred_list.go",
"encode.go",
"encode_unsafe.go",
- "pending_list.go",
"state.go",
"state_norace.go",
"state_race.go",
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
index c9971cdf6..e519ddeca 100644
--- a/pkg/state/decode.go
+++ b/pkg/state/decode.go
@@ -21,6 +21,7 @@ import (
"math"
"reflect"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/state/wire"
)
@@ -258,7 +259,7 @@ func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, c
// For the purposes of this function, a child object is either a field within a
// struct or an array element, with one such indirection per element in
// path. The returned value may be an unexported field, so it may not be
-// directly assignable. See unsafePointerTo.
+// directly assignable. See decode_unsafe.go.
func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
// See wire.Ref.Dots. The path here is specified in reverse order.
for i := len(path) - 1; i >= 0; i-- {
@@ -519,9 +520,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e
// Normal assignment: authoritative only if no dots.
v := ds.register(x, obj.Type().Elem())
- if v.IsValid() {
- obj.Set(unsafePointerTo(v))
- }
+ obj.Set(reflectValueRWAddr(v))
case wire.Bool:
obj.SetBool(bool(x))
case wire.Int:
@@ -559,7 +558,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e
// contents will still be filled in later on.
typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type.
v := ds.register(&x.Ref, typ)
- obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity)))
+ obj.Set(reflectValueRWSlice3(v, 0, int(x.Length), int(x.Capacity)))
case *wire.Array:
ds.decodeArray(ods, obj, x)
case *wire.Struct:
@@ -584,13 +583,15 @@ func (ds *decodeState) Load(obj reflect.Value) {
})
// Create the root object.
- ds.objectsByID = append(ds.objectsByID, &objectDecodeState{
+ rootOds := &objectDecodeState{
id: 1,
obj: obj,
- })
+ }
+ ds.objectsByID = append(ds.objectsByID, rootOds)
+ ds.pending.PushBack(rootOds)
// Read the number of objects.
- lastID, object, err := ReadHeader(ds.r)
+ numObjects, object, err := ReadHeader(ds.r)
if err != nil {
Failf("header error: %w", err)
}
@@ -602,42 +603,44 @@ func (ds *decodeState) Load(obj reflect.Value) {
var (
encoded wire.Object
ods *objectDecodeState
- id = objectID(1)
+ id objectID
tid = typeID(1)
)
if err := safely(func() {
// Decode all objects in the stream.
//
- // Note that the structure of this decoding loop should match
- // the raw decoding loop in printer.go.
- for id <= objectID(lastID) {
- // Unmarshal the object.
+ // Note that the structure of this decoding loop should match the raw
+ // decoding loop in state/pretty/pretty.printer.printStream().
+ for i := uint64(0); i < numObjects; {
+ // Unmarshal either a type object or object ID.
encoded = wire.Load(ds.r)
-
- // Is this a type object? Handle inline.
- if wt, ok := encoded.(*wire.Type); ok {
- ds.types.Register(wt)
+ switch we := encoded.(type) {
+ case *wire.Type:
+ ds.types.Register(we)
tid++
encoded = nil
continue
+ case wire.Uint:
+ id = objectID(we)
+ i++
+ // Unmarshal and resolve the actual object.
+ encoded = wire.Load(ds.r)
+ ods = ds.lookup(id)
+ if ods != nil {
+ // Decode the object.
+ ds.decodeObject(ods, ods.obj, encoded)
+ } else {
+ // If an object hasn't had interest registered
+ // previously or isn't yet valid, we deferred
+ // decoding until interest is registered.
+ ds.deferred[id] = encoded
+ }
+ // For error handling.
+ ods = nil
+ encoded = nil
+ default:
+ Failf("wanted type or object ID, got %#v", encoded)
}
-
- // Actually resolve the object.
- ods = ds.lookup(id)
- if ods != nil {
- // Decode the object.
- ds.decodeObject(ods, ods.obj, encoded)
- } else {
- // If an object hasn't had interest registered
- // previously or isn't yet valid, we deferred
- // decoding until interest is registered.
- ds.deferred[id] = encoded
- }
-
- // For error handling.
- ods = nil
- encoded = nil
- id++
}
}); err != nil {
// Include as much information as we can, taking into account
@@ -645,16 +648,25 @@ func (ds *decodeState) Load(obj reflect.Value) {
if ods != nil {
Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err)
} else if encoded != nil {
- Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err)
+ Failf("error decoding from %#v: %w", encoded, err)
} else {
Failf("general decoding error: %w", err)
}
}
// Check if we have any deferred objects.
+ numDeferred := 0
for id, encoded := range ds.deferred {
- // Shoud never happen, the graph was bogus.
- Failf("still have deferred objects: one is ID %d, %#v", id, encoded)
+ numDeferred++
+ if s, ok := encoded.(*wire.Struct); ok && s.TypeID != 0 {
+ typ := ds.types.LookupType(typeID(s.TypeID))
+ log.Warningf("unused deferred object: ID %d, type %v", id, typ)
+ } else {
+ log.Warningf("unused deferred object: ID %d, %#v", id, encoded)
+ }
+ }
+ if numDeferred != 0 {
+ Failf("still had %d deferred objects", numDeferred)
}
// Scan and fire all callbacks. We iterate over the list of incomplete
diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go
index d048f61a1..f1208e2a2 100644
--- a/pkg/state/decode_unsafe.go
+++ b/pkg/state/decode_unsafe.go
@@ -15,13 +15,62 @@
package state
import (
+ "fmt"
"reflect"
+ "runtime"
"unsafe"
)
-// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on
-// values representing unexported fields. This bypasses visibility, but not
-// type safety.
-func unsafePointerTo(obj reflect.Value) reflect.Value {
+// reflectValueRWAddr is equivalent to obj.Addr(), except that the returned
+// reflect.Value is usable in assignments even if obj was obtained by the use
+// of unexported struct fields.
+//
+// Preconditions: obj.CanAddr().
+func reflectValueRWAddr(obj reflect.Value) reflect.Value {
return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr()))
}
+
+// reflectValueRWSlice3 is equivalent to arr.Slice3(i, j, k), except that the
+// returned reflect.Value is usable in assignments even if obj was obtained by
+// the use of unexported struct fields.
+//
+// Preconditions:
+// * arr.Kind() == reflect.Array.
+// * i, j, k >= 0.
+// * i <= j <= k <= arr.Len().
+func reflectValueRWSlice3(arr reflect.Value, i, j, k int) reflect.Value {
+ if arr.Kind() != reflect.Array {
+ panic(fmt.Sprintf("arr has kind %v, wanted %v", arr.Kind(), reflect.Array))
+ }
+ if i < 0 || j < 0 || k < 0 {
+ panic(fmt.Sprintf("negative subscripts (%d, %d, %d)", i, j, k))
+ }
+ if i > j {
+ panic(fmt.Sprintf("subscript i (%d) > j (%d)", i, j))
+ }
+ if j > k {
+ panic(fmt.Sprintf("subscript j (%d) > k (%d)", j, k))
+ }
+ if k > arr.Len() {
+ panic(fmt.Sprintf("subscript k (%d) > array length (%d)", k, arr.Len()))
+ }
+
+ sliceTyp := reflect.SliceOf(arr.Type().Elem())
+ if i == arr.Len() {
+ // By precondition, i == j == k == arr.Len().
+ return reflect.MakeSlice(sliceTyp, 0, 0)
+ }
+ slh := reflect.SliceHeader{
+ // reflect.Value.CanAddr() == false for arrays, so we need to get the
+ // address from the first element of the array.
+ Data: arr.Index(i).UnsafeAddr(),
+ Len: j - i,
+ Cap: k - i,
+ }
+ slobj := reflect.NewAt(sliceTyp, unsafe.Pointer(&slh)).Elem()
+ // Before slobj is constructed, arr holds the only pointer-typed pointer to
+ // the array since reflect.SliceHeader.Data is a uintptr, so arr must be
+ // kept alive.
+ runtime.KeepAlive(arr)
+ return slobj
+}
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
index 92fcad4e9..560e7c2a3 100644
--- a/pkg/state/encode.go
+++ b/pkg/state/encode.go
@@ -17,13 +17,14 @@ package state
import (
"context"
"reflect"
+ "sort"
"gvisor.dev/gvisor/pkg/state/wire"
)
// objectEncodeState the type and identity of an object occupying a memory
// address range. This is the value type for addrSet, and the intrusive entry
-// for the pending and deferred lists.
+// for the deferred list.
type objectEncodeState struct {
// id is the assigned ID for this object.
id objectID
@@ -47,7 +48,6 @@ type objectEncodeState struct {
// references may be updated directly and automatically.
refs []*wire.Ref
- pendingEntry
deferredEntry
}
@@ -93,9 +93,15 @@ type encodeState struct {
// serialized.
pendingTypes []wire.Type
- // pending is the list of objects to be serialized. Serialization does
+ // pending maps object IDs to objects to be serialized. Serialization does
// not actually occur until the full object graph is computed.
- pending pendingList
+ pending map[objectID]*objectEncodeState
+
+ // encodedStructs maps reflect.Values representing structs to previous
+ // encodings of those structs. This is necessary to avoid duplicate calls
+ // to SaverLoader.StateSave() that may result in multiple calls to
+ // Sink.SaveValue() for a given field, resulting in object duplication.
+ encodedStructs map[reflect.Value]*wire.Struct
// stats tracks time data.
stats Stats
@@ -189,7 +195,8 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
// depending on this value knows there's nothing there.
return
}
- if seg, _ := es.values.Find(addr); seg.Ok() {
+ seg, gap := es.values.Find(addr)
+ if seg.Ok() {
// Ensure the map types match.
existing := seg.Value()
if existing.obj.Type() != obj.Type() {
@@ -203,13 +210,20 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
}
// Record the map.
+ r := addrRange{addr, addr + 1}
oes := &objectEncodeState{
id: es.nextID(),
obj: obj,
how: encodeMapAsValue,
}
- es.values.Add(addrRange{addr, addr + 1}, oes)
- es.pending.PushBack(oes)
+ // Use Insert instead of InsertWithoutMergingUnchecked when race
+ // detection is enabled to get additional sanity-checking from Merge.
+ if !raceEnabled {
+ es.values.InsertWithoutMergingUnchecked(gap, r, oes)
+ } else {
+ es.values.Insert(gap, r, oes)
+ }
+ es.pending[oes.id] = oes
es.deferred.PushBack(oes)
// See above: no ref recording.
@@ -245,7 +259,7 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
obj: obj,
}
es.zeroValues[typ] = oes
- es.pending.PushBack(oes)
+ es.pending[oes.id] = oes
es.deferred.PushBack(oes)
}
@@ -258,86 +272,112 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
size = 1 // See above.
}
- // Calculate the container.
end := addr + size
r := addrRange{addr, end}
- if seg, _ := es.values.Find(addr); seg.Ok() {
+ seg := es.values.LowerBoundSegment(addr)
+ var (
+ oes *objectEncodeState
+ gap addrGapIterator
+ )
+
+ // Does at least one previously-registered object overlap this one?
+ if seg.Ok() && seg.Start() < end {
existing := seg.Value()
- switch {
- case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type():
- // The object is a perfect match. Happy path. Avoid the
- // traversal and just return directly. We don't need to
- // encode the type information or any dots here.
+
+ if seg.Range() == r && typ == existing.obj.Type() {
+ // This exact object is already registered. Avoid the traversal and
+ // just return directly. We don't need to encode the type
+ // information or any dots here.
ref.Root = wire.Uint(existing.id)
existing.refs = append(existing.refs, ref)
return
+ }
- case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end):
- // The previously registered object is larger than
- // this, no need to update. But we expect some
- // traversal below.
+ if seg.Range().IsSupersetOf(r) && (seg.Range() != r || isSameSizeParent(existing.obj, typ)) {
+ // This object is contained within a previously-registered object.
+ // Perform traversal from the container to the new object.
+ ref.Root = wire.Uint(existing.id)
+ ref.Dots = traverse(existing.obj.Type(), typ, seg.Start(), addr)
+ ref.Type = es.findType(existing.obj.Type())
+ existing.refs = append(existing.refs, ref)
+ return
+ }
- case seg.Start() == addr && seg.End() == end:
- if !isSameSizeParent(obj, existing.obj.Type()) {
- break // Needs traversal.
+ // This object contains one or more previously-registered objects.
+ // Remove them and update existing references to use the new one.
+ oes := &objectEncodeState{
+ // Reuse the root ID of the first contained element.
+ id: existing.id,
+ obj: obj,
+ }
+ type elementEncodeState struct {
+ addr uintptr
+ typ reflect.Type
+ refs []*wire.Ref
+ }
+ var (
+ elems []elementEncodeState
+ gap addrGapIterator
+ )
+ for {
+ // Each contained object should be completely contained within
+ // this one.
+ if raceEnabled && !r.IsSupersetOf(seg.Range()) {
+ Failf("containing object %#v does not contain existing object %#v", obj, existing.obj)
}
- fallthrough // Needs update.
-
- case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end):
- // Update the object and redo the encoding.
- old := existing.obj
- existing.obj = obj
+ elems = append(elems, elementEncodeState{
+ addr: seg.Start(),
+ typ: existing.obj.Type(),
+ refs: existing.refs,
+ })
+ delete(es.pending, existing.id)
es.deferred.Remove(existing)
- es.deferred.PushBack(existing)
-
- // The previously registered object is superseded by
- // this new object. We are guaranteed to not have any
- // mergeable neighbours in this segment set.
- if !raceEnabled {
- seg.SetRangeUnchecked(r)
- } else {
- // Add extra paranoid. This will be statically
- // removed at compile time unless a race build.
- es.values.Remove(seg)
- es.values.Add(r, existing)
- seg = es.values.LowerBoundSegment(addr)
+ gap = es.values.Remove(seg)
+ seg = gap.NextSegment()
+ if !seg.Ok() || seg.Start() >= end {
+ break
}
-
- // Compute the traversal required & update references.
- dots := traverse(obj.Type(), old.Type(), addr, seg.Start())
- wt := es.findType(obj.Type())
- for _, ref := range existing.refs {
+ existing = seg.Value()
+ }
+ wt := es.findType(typ)
+ for _, elem := range elems {
+ dots := traverse(typ, elem.typ, addr, elem.addr)
+ for _, ref := range elem.refs {
+ ref.Root = wire.Uint(oes.id)
ref.Dots = append(ref.Dots, dots...)
ref.Type = wt
}
- default:
- // There is a non-sensical overlap.
- Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj)
+ oes.refs = append(oes.refs, elem.refs...)
}
-
- // Compute the new reference, record and return it.
- ref.Root = wire.Uint(existing.id)
- ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr)
- ref.Type = es.findType(obj.Type())
- existing.refs = append(existing.refs, ref)
+ // Finally register the new containing object.
+ if !raceEnabled {
+ es.values.InsertWithoutMergingUnchecked(gap, r, oes)
+ } else {
+ es.values.Insert(gap, r, oes)
+ }
+ es.pending[oes.id] = oes
+ es.deferred.PushBack(oes)
+ ref.Root = wire.Uint(oes.id)
+ oes.refs = append(oes.refs, ref)
return
}
- // The only remaining case is a pointer value that doesn't overlap with
- // any registered addresses. Create a new entry for it, and start
- // tracking the first reference we just created.
- oes := &objectEncodeState{
+ // No existing object overlaps this one. Register a new object.
+ oes = &objectEncodeState{
id: es.nextID(),
obj: obj,
}
+ if seg.Ok() {
+ gap = seg.PrevGap()
+ } else {
+ gap = es.values.LastGap()
+ }
if !raceEnabled {
- es.values.AddWithoutMerging(r, oes)
+ es.values.InsertWithoutMergingUnchecked(gap, r, oes)
} else {
- // Merges should never happen. This is just enabled extra
- // sanity checks because the Merge function below will panic.
- es.values.Add(r, oes)
+ es.values.Insert(gap, r, oes)
}
- es.pending.PushBack(oes)
+ es.pending[oes.id] = oes
es.deferred.PushBack(oes)
ref.Root = wire.Uint(oes.id)
oes.refs = append(oes.refs, ref)
@@ -439,6 +479,14 @@ func (oe *objectEncoder) save(slot int, obj reflect.Value) {
// encodeStruct encodes a composite object.
func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
+ if s, ok := es.encodedStructs[obj]; ok {
+ *dest = s
+ return
+ }
+ s := &wire.Struct{}
+ *dest = s
+ es.encodedStructs[obj] = s
+
// Ensure that the obj is addressable. There are two cases when it is
// not. First, is when this is dispatched via SaveValue. Second, when
// this is a map key as a struct. Either way, we need to make a copy to
@@ -449,10 +497,6 @@ func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
obj = localObj.Elem()
}
- // Prepare the value.
- s := &wire.Struct{}
- *dest = s
-
// Look the type up in the database.
te, ok := es.types.Lookup(obj.Type())
if te == nil {
@@ -730,45 +774,43 @@ func (es *encodeState) Save(obj reflect.Value) {
Failf("encoding error at object %#v: %w", oes.obj.Interface(), err)
}
- // Check that items are pending.
- if es.pending.Front() == nil {
+ // Check that we have objects to serialize.
+ if len(es.pending) == 0 {
Failf("pending is empty?")
}
- // Write the header with the number of objects. Note that there is no
- // way that es.lastID could conflict with objectID, which would
- // indicate that an impossibly large encoding.
- if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil {
+ // Write the header with the number of objects.
+ if err := WriteHeader(es.w, uint64(len(es.pending)), true); err != nil {
Failf("error writing header: %w", err)
}
// Serialize all pending types and pending objects. Note that we don't
// bother removing from this list as we walk it because that just
// wastes time. It will not change after this point.
- var id objectID
if err := safely(func() {
for _, wt := range es.pendingTypes {
// Encode the type.
wire.Save(es.w, &wt)
}
- for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() {
- id++ // First object is 1.
- if oes.id != id {
- Failf("expected id %d, got %d", id, oes.id)
- }
-
- // Marshall the object.
+ // Emit objects in ID order.
+ ids := make([]objectID, 0, len(es.pending))
+ for id := range es.pending {
+ ids = append(ids, id)
+ }
+ sort.Slice(ids, func(i, j int) bool {
+ return ids[i] < ids[j]
+ })
+ for _, id := range ids {
+ // Encode the id.
+ wire.Save(es.w, wire.Uint(id))
+ // Marshal the object.
+ oes := es.pending[id]
wire.Save(es.w, oes.encoded)
}
}); err != nil {
// Include the object and the error.
Failf("error serializing object %#v: %w", oes.encoded, err)
}
-
- // Check what we wrote.
- if id != es.lastID {
- Failf("expected %d objects, wrote %d", es.lastID, id)
- }
}
// objectFlag indicates that the length is a # of objects, rather than a raw
@@ -797,11 +839,6 @@ func WriteHeader(w wire.Writer, length uint64, object bool) error {
})
}
-// pendingMapper is for the pending list.
-type pendingMapper struct{}
-
-func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry }
-
// deferredMapper is for the deferred list.
type deferredMapper struct{}
diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go
index cf37aaa49..c6e8bb31d 100644
--- a/pkg/state/pretty/pretty.go
+++ b/pkg/state/pretty/pretty.go
@@ -26,17 +26,23 @@ import (
"gvisor.dev/gvisor/pkg/state/wire"
)
-func formatRef(x *wire.Ref, graph uint64, html bool) string {
+type printer struct {
+ html bool
+ typeSpecs map[string]*wire.Type
+}
+
+func (p *printer) formatRef(x *wire.Ref, graph uint64) string {
baseRef := fmt.Sprintf("g%dr%d", graph, x.Root)
fullRef := baseRef
if len(x.Dots) > 0 {
// See wire.Ref; Type valid if Dots non-zero.
- typ, _ := formatType(x.Type, graph, html)
+ typ, _ := p.formatType(x.Type, graph)
var buf strings.Builder
buf.WriteString("(*")
buf.WriteString(typ)
buf.WriteString(")(")
buf.WriteString(baseRef)
+ buf.WriteString(")")
for _, component := range x.Dots {
switch v := component.(type) {
case *wire.FieldName:
@@ -48,37 +54,42 @@ func formatRef(x *wire.Ref, graph uint64, html bool) string {
panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component)))
}
}
- buf.WriteString(")")
fullRef = buf.String()
}
- if html {
+ if p.html {
return fmt.Sprintf("<a href=\"#%s\">%s</a>", baseRef, fullRef)
}
return fullRef
}
-func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) {
+func (p *printer) formatType(t wire.TypeSpec, graph uint64) (string, bool) {
switch x := t.(type) {
case wire.TypeID:
- base := fmt.Sprintf("g%dt%d", graph, x)
- if html {
- return fmt.Sprintf("<a href=\"#%s\">%s</a>", base, base), true
+ tag := fmt.Sprintf("g%dt%d", graph, x)
+ desc := tag
+ if spec, ok := p.typeSpecs[tag]; ok {
+ desc += fmt.Sprintf("=%s", spec.Name)
+ } else {
+ desc += "!missing-type-spec"
}
- return fmt.Sprintf("%s", base), true
+ if p.html {
+ return fmt.Sprintf("<a href=\"#%s\">%s</a>", tag, desc), true
+ }
+ return desc, true
case wire.TypeSpecNil:
return "", false // Only nil type.
case *wire.TypeSpecPointer:
- element, _ := formatType(x.Type, graph, html)
+ element, _ := p.formatType(x.Type, graph)
return fmt.Sprintf("(*%s)", element), true
case *wire.TypeSpecArray:
- element, _ := formatType(x.Type, graph, html)
+ element, _ := p.formatType(x.Type, graph)
return fmt.Sprintf("[%d](%s)", x.Count, element), true
case *wire.TypeSpecSlice:
- element, _ := formatType(x.Type, graph, html)
+ element, _ := p.formatType(x.Type, graph)
return fmt.Sprintf("([]%s)", element), true
case *wire.TypeSpecMap:
- key, _ := formatType(x.Key, graph, html)
- value, _ := formatType(x.Value, graph, html)
+ key, _ := p.formatType(x.Key, graph)
+ value, _ := p.formatType(x.Value, graph)
return fmt.Sprintf("(map[%s]%s)", key, value), true
default:
panic(fmt.Sprintf("unreachable: unknown type %T", t))
@@ -87,7 +98,7 @@ func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) {
// format formats a single object, for pretty-printing. It also returns whether
// the value is a non-zero value.
-func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bool) {
+func (p *printer) format(graph uint64, depth int, encoded wire.Object) (string, bool) {
switch x := encoded.(type) {
case wire.Nil:
return "nil", false
@@ -98,7 +109,7 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo
case *wire.Complex128:
return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0
case *wire.Ref:
- return formatRef(x, graph, html), x.Root != 0
+ return p.formatRef(x, graph), x.Root != 0
case *wire.Type:
tabs := "\n" + strings.Repeat("\t", depth)
items := make([]string, 0, len(x.Fields)+2)
@@ -109,7 +120,7 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo
items = append(items, "}")
return strings.Join(items, tabs), true // No zero value.
case *wire.Slice:
- return fmt.Sprintf("%s{len:%d,cap:%d}", formatRef(&x.Ref, graph, html), x.Length, x.Capacity), x.Capacity != 0
+ return fmt.Sprintf("%s{len:%d,cap:%d}", p.formatRef(&x.Ref, graph), x.Length, x.Capacity), x.Capacity != 0
case *wire.Array:
if len(x.Contents) == 0 {
return "[]", false
@@ -119,7 +130,7 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo
items = append(items, "[")
tabs := "\n" + strings.Repeat("\t", depth)
for i := 0; i < len(x.Contents); i++ {
- item, ok := format(graph, depth+1, x.Contents[i], html)
+ item, ok := p.format(graph, depth+1, x.Contents[i])
if !ok {
zeros = append(zeros, fmt.Sprintf("\t%s,", item))
continue
@@ -136,7 +147,9 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo
items = append(items, "]")
return strings.Join(items, tabs), len(zeros) < len(x.Contents)
case *wire.Struct:
- typ, _ := formatType(x.TypeID, graph, html)
+ tag := fmt.Sprintf("g%dt%d", graph, x.TypeID)
+ spec, _ := p.typeSpecs[tag]
+ typ, _ := p.formatType(x.TypeID, graph)
if x.Fields() == 0 {
return fmt.Sprintf("struct[%s]{}", typ), false
}
@@ -145,10 +158,15 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo
tabs := "\n" + strings.Repeat("\t", depth)
allZero := true
for i := 0; i < x.Fields(); i++ {
- element, ok := format(graph, depth+1, *x.Field(i), html)
+ var name string
+ if spec != nil && i < len(spec.Fields) {
+ name = spec.Fields[i]
+ } else {
+ name = fmt.Sprintf("%d", i)
+ }
+ element, ok := p.format(graph, depth+1, *x.Field(i))
allZero = allZero && !ok
- items = append(items, fmt.Sprintf("\t%d: %s,", i, element))
- i++
+ items = append(items, fmt.Sprintf("\t%s: %s,", name, element))
}
items = append(items, "}")
return strings.Join(items, tabs), !allZero
@@ -160,15 +178,15 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo
items = append(items, "map{")
tabs := "\n" + strings.Repeat("\t", depth)
for i := 0; i < len(x.Keys); i++ {
- key, _ := format(graph, depth+1, x.Keys[i], html)
- value, _ := format(graph, depth+1, x.Values[i], html)
+ key, _ := p.format(graph, depth+1, x.Keys[i])
+ value, _ := p.format(graph, depth+1, x.Values[i])
items = append(items, fmt.Sprintf("\t%s: %s,", key, value))
}
items = append(items, "}")
return strings.Join(items, tabs), true
case *wire.Interface:
- typ, typOk := formatType(x.Type, graph, html)
- element, elementOk := format(graph, depth+1, x.Value, html)
+ typ, typOk := p.formatType(x.Type, graph)
+ element, elementOk := p.format(graph, depth+1, x.Value)
return fmt.Sprintf("interface[%s]{%s}", typ, element), typOk || elementOk
default:
// Must be a primitive; use reflection.
@@ -177,11 +195,11 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo
}
// printStream is the basic print implementation.
-func printStream(w io.Writer, r wire.Reader, html bool) (err error) {
+func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) {
// current graph ID.
var graph uint64
- if html {
+ if p.html {
fmt.Fprintf(w, "<pre>")
defer fmt.Fprintf(w, "</pre>")
}
@@ -196,6 +214,8 @@ func printStream(w io.Writer, r wire.Reader, html bool) (err error) {
}
}()
+ p.typeSpecs = make(map[string]*wire.Type)
+
for {
// Find the first object to begin generation.
length, object, err := state.ReadHeader(r)
@@ -222,19 +242,23 @@ func printStream(w io.Writer, r wire.Reader, html bool) (err error) {
// Note that this loop must match the general structure of the
// loop in decode.go. But we don't register type information,
// etc. and just print the raw structures.
+ type objectAndID struct {
+ id uint64
+ obj wire.Object
+ }
var (
- oid uint64 = 1
- tid uint64 = 1
+ tid uint64 = 1
+ objects []objectAndID
)
- for oid <= length {
- // Unmarshal the object.
+ for i := uint64(0); i < length; {
+ // Unmarshal either a type object or object ID.
encoded := wire.Load(r)
-
- // Is this a type?
- if _, ok := encoded.(*wire.Type); ok {
- str, _ := format(graph, 0, encoded, html)
+ switch we := encoded.(type) {
+ case *wire.Type:
+ str, _ := p.format(graph, 0, encoded)
tag := fmt.Sprintf("g%dt%d", graph, tid)
- if html {
+ p.typeSpecs[tag] = we
+ if p.html {
// See below.
tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
}
@@ -242,20 +266,29 @@ func printStream(w io.Writer, r wire.Reader, html bool) (err error) {
return err
}
tid++
- continue
+ case wire.Uint:
+ // Unmarshal the actual object.
+ objects = append(objects, objectAndID{
+ id: uint64(we),
+ obj: wire.Load(r),
+ })
+ i++
+ default:
+ return fmt.Errorf("wanted type or object ID, got %#v", encoded)
}
+ }
+ for _, objAndID := range objects {
// Format the node.
- str, _ := format(graph, 0, encoded, html)
- tag := fmt.Sprintf("g%dr%d", graph, oid)
- if html {
+ str, _ := p.format(graph, 0, objAndID.obj)
+ tag := fmt.Sprintf("g%dr%d", graph, objAndID.id)
+ if p.html {
// Create a little tag with an anchor next to it for linking.
tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
}
if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
return err
}
- oid++
}
}
@@ -264,10 +297,10 @@ func printStream(w io.Writer, r wire.Reader, html bool) (err error) {
// PrintText reads the stream from r and prints text to w.
func PrintText(w io.Writer, r wire.Reader) error {
- return printStream(w, r, false /* html */)
+ return (&printer{}).printStream(w, r)
}
// PrintHTML reads the stream from r and prints html to w.
func PrintHTML(w io.Writer, r wire.Reader) error {
- return printStream(w, r, true /* html */)
+ return (&printer{html: true}).printStream(w, r)
}
diff --git a/pkg/state/state.go b/pkg/state/state.go
index acb629969..6b8540f03 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -90,10 +90,12 @@ func (e *ErrState) Unwrap() error {
func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) {
// Create the encoding state.
es := encodeState{
- ctx: ctx,
- w: w,
- types: makeTypeEncodeDatabase(),
- zeroValues: make(map[reflect.Type]*objectEncodeState),
+ ctx: ctx,
+ w: w,
+ types: makeTypeEncodeDatabase(),
+ zeroValues: make(map[reflect.Type]*objectEncodeState),
+ pending: make(map[objectID]*objectEncodeState),
+ encodedStructs: make(map[reflect.Value]*wire.Struct),
}
// Perform the encoding.
diff --git a/pkg/state/tests/load_test.go b/pkg/state/tests/load_test.go
index 1e9794296..3c73ac391 100644
--- a/pkg/state/tests/load_test.go
+++ b/pkg/state/tests/load_test.go
@@ -20,6 +20,14 @@ import (
func TestLoadHooks(t *testing.T) {
runTestCases(t, false, "load-hooks", []interface{}{
+ // Root object being a struct.
+ afterLoadStruct{v: 1},
+ valueLoadStruct{v: 1},
+ genericContainer{v: &afterLoadStruct{v: 1}},
+ genericContainer{v: &valueLoadStruct{v: 1}},
+ sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}},
+ sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}},
+ // Root object being a pointer.
&afterLoadStruct{v: 1},
&valueLoadStruct{v: 1},
&genericContainer{v: &afterLoadStruct{v: 1}},
diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go
index bd2c2b399..69143d194 100644
--- a/pkg/state/tests/struct.go
+++ b/pkg/state/tests/struct.go
@@ -54,12 +54,47 @@ type outerArray struct {
}
// +stateify savable
+type outerSlice struct {
+ inner []inner
+}
+
+// +stateify savable
type inner struct {
v int64
}
// +stateify savable
+type outerFieldValue struct {
+ inner innerFieldValue
+}
+
+// +stateify savable
+type innerFieldValue struct {
+ v int64 `state:".(*savedFieldValue)"`
+}
+
+// +stateify savable
+type savedFieldValue struct {
+ v int64
+}
+
+func (ifv *innerFieldValue) saveV() *savedFieldValue {
+ return &savedFieldValue{ifv.v}
+}
+
+func (ifv *innerFieldValue) loadV(sfv *savedFieldValue) {
+ ifv.v = sfv.v
+}
+
+// +stateify savable
type system struct {
v1 interface{}
v2 interface{}
}
+
+// +stateify savable
+type system3 struct {
+ v1 interface{}
+ v2 interface{}
+ v3 interface{}
+}
diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go
index de9d17aa7..c91c2c032 100644
--- a/pkg/state/tests/struct_test.go
+++ b/pkg/state/tests/struct_test.go
@@ -15,6 +15,7 @@
package tests
import (
+ "math/rand"
"testing"
"gvisor.dev/gvisor/pkg/state"
@@ -67,12 +68,23 @@ func TestRegisterTypeOnlyStruct(t *testing.T) {
}
func TestEmbeddedPointers(t *testing.T) {
- var (
- ofs outerSame
- of1 outerFieldFirst
- of2 outerFieldSecond
- oa outerArray
- )
+ // Give each int64 a random value to prevent Go from using
+ // runtime.staticuint64s, which confounds tests for struct duplication.
+ magic := func() int64 {
+ for {
+ n := rand.Int63()
+ if n < 0 || n > 255 {
+ return n
+ }
+ }
+ }
+
+ ofs := outerSame{inner{magic()}}
+ of1 := outerFieldFirst{inner{magic()}, magic()}
+ of2 := outerFieldSecond{magic(), inner{magic()}}
+ oa := outerArray{[2]inner{{magic()}, {magic()}}}
+ osl := outerSlice{oa.inner[:]}
+ ofv := outerFieldValue{innerFieldValue{magic()}}
runTestCases(t, false, "embedded-pointers", []interface{}{
system{&ofs, &ofs.inner},
@@ -85,5 +97,15 @@ func TestEmbeddedPointers(t *testing.T) {
system{&oa, &oa.inner[1]},
system{&oa.inner[0], &oa},
system{&oa.inner[1], &oa},
+ system3{&oa, &oa.inner[0], &oa.inner[1]},
+ system3{&oa, &oa.inner[1], &oa.inner[0]},
+ system3{&oa.inner[0], &oa, &oa.inner[1]},
+ system3{&oa.inner[1], &oa, &oa.inner[0]},
+ system3{&oa.inner[0], &oa.inner[1], &oa},
+ system3{&oa.inner[1], &oa.inner[0], &oa},
+ system{&oa, &osl},
+ system{&osl, &oa},
+ system{&ofv, &ofv.inner},
+ system{&ofv.inner, &ofv},
})
}
diff --git a/pkg/state/types.go b/pkg/state/types.go
index 215ef80f8..84aed8732 100644
--- a/pkg/state/types.go
+++ b/pkg/state/types.go
@@ -107,6 +107,14 @@ func lookupNameFields(typ reflect.Type) (string, []string, bool) {
}
return name, nil, true
}
+ // Sanity check the type.
+ if raceEnabled {
+ if _, ok := reverseTypeDatabase[typ]; !ok {
+ // The type was not registered? Must be an embedded
+ // structure or something else.
+ return "", nil, false
+ }
+ }
// Extract the name from the object.
name := t.StateTypeName()
fields := t.StateFields()
@@ -313,6 +321,9 @@ var primitiveTypeDatabase = func() map[string]reflect.Type {
// globalTypeDatabase is used for dispatching interfaces on decode.
var globalTypeDatabase = map[string]reflect.Type{}
+// reverseTypeDatabase is a reverse mapping.
+var reverseTypeDatabase = map[reflect.Type]string{}
+
// Register registers a type.
//
// This must be called on init and only done once.
@@ -358,4 +369,7 @@ func Register(t Type) {
Failf("conflicting name for %T: matches interfaceType", t)
}
globalTypeDatabase[name] = typ
+ if raceEnabled {
+ reverseTypeDatabase[typ] = name
+ }
}