summaryrefslogtreecommitdiffhomepage
path: root/pkg/state
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/state')
-rw-r--r--pkg/state/BUILD100
-rw-r--r--pkg/state/README.md158
-rw-r--r--pkg/state/decode.go725
-rw-r--r--pkg/state/decode_unsafe.go27
-rw-r--r--pkg/state/encode.go841
-rw-r--r--pkg/state/encode_unsafe.go33
-rw-r--r--pkg/state/pretty/BUILD13
-rw-r--r--pkg/state/pretty/pretty.go273
-rw-r--r--pkg/state/state.go321
-rw-r--r--pkg/state/state_norace.go19
-rw-r--r--pkg/state/state_race.go19
-rw-r--r--pkg/state/statefile/BUILD22
-rw-r--r--pkg/state/statefile/statefile.go239
-rw-r--r--pkg/state/statefile/statefile_test.go290
-rw-r--r--pkg/state/stats.go145
-rw-r--r--pkg/state/tests/BUILD43
-rw-r--r--pkg/state/tests/array.go35
-rw-r--r--pkg/state/tests/array_test.go134
-rw-r--r--pkg/state/tests/bench.go24
-rw-r--r--pkg/state/tests/bench_test.go153
-rw-r--r--pkg/state/tests/bool_test.go31
-rw-r--r--pkg/state/tests/float_test.go118
-rw-r--r--pkg/state/tests/integer.go163
-rw-r--r--pkg/state/tests/integer_test.go94
-rw-r--r--pkg/state/tests/load.go61
-rw-r--r--pkg/state/tests/load_test.go70
-rw-r--r--pkg/state/tests/map.go28
-rw-r--r--pkg/state/tests/map_test.go90
-rw-r--r--pkg/state/tests/register.go21
-rw-r--r--pkg/state/tests/register_test.go167
-rw-r--r--pkg/state/tests/string_test.go34
-rw-r--r--pkg/state/tests/struct.go65
-rw-r--r--pkg/state/tests/struct_test.go89
-rw-r--r--pkg/state/tests/tests.go215
-rw-r--r--pkg/state/types.go361
-rw-r--r--pkg/state/wire/BUILD12
-rw-r--r--pkg/state/wire/wire.go970
37 files changed, 6203 insertions, 0 deletions
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
new file mode 100644
index 000000000..089b3bbef
--- /dev/null
+++ b/pkg/state/BUILD
@@ -0,0 +1,100 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "pending_list",
+ out = "pending_list.go",
+ package = "state",
+ prefix = "pending",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectEncodeState",
+ "ElementMapper": "pendingMapper",
+ "Linker": "*pendingEntry",
+ },
+)
+
+go_template_instance(
+ name = "deferred_list",
+ out = "deferred_list.go",
+ package = "state",
+ prefix = "deferred",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectEncodeState",
+ "ElementMapper": "deferredMapper",
+ "Linker": "*deferredEntry",
+ },
+)
+
+go_template_instance(
+ name = "complete_list",
+ out = "complete_list.go",
+ package = "state",
+ prefix = "complete",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectDecodeState",
+ "Linker": "*objectDecodeState",
+ },
+)
+
+go_template_instance(
+ name = "addr_range",
+ out = "addr_range.go",
+ package = "state",
+ prefix = "addr",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "uintptr",
+ },
+)
+
+go_template_instance(
+ name = "addr_set",
+ out = "addr_set.go",
+ consts = {
+ "minDegree": "10",
+ },
+ imports = {
+ "reflect": "reflect",
+ },
+ package = "state",
+ prefix = "addr",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uintptr",
+ "Range": "addrRange",
+ "Value": "*objectEncodeState",
+ "Functions": "addrSetFunctions",
+ },
+)
+
+go_library(
+ name = "state",
+ srcs = [
+ "addr_range.go",
+ "addr_set.go",
+ "complete_list.go",
+ "decode.go",
+ "decode_unsafe.go",
+ "deferred_list.go",
+ "encode.go",
+ "encode_unsafe.go",
+ "pending_list.go",
+ "state.go",
+ "state_norace.go",
+ "state_race.go",
+ "stats.go",
+ "types.go",
+ ],
+ marshal = False,
+ stateify = False,
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "//pkg/state/wire",
+ ],
+)
diff --git a/pkg/state/README.md b/pkg/state/README.md
new file mode 100644
index 000000000..1aa401193
--- /dev/null
+++ b/pkg/state/README.md
@@ -0,0 +1,158 @@
+# State Encoding and Decoding
+
+The state package implements the encoding and decoding of data structures for
+`go_stateify`. This package is designed for use cases other than the standard
+encoding packages, e.g. `gob` and `json`. Principally:
+
+* This package operates on complex object graphs and accurately serializes and
+ restores all relationships. That is, you can have things like: intrusive
+ pointers, cycles, and pointer chains of arbitrary depths. These are not
+ handled appropriately by existing encoders. This is not an implementation
+ flaw: the formats themselves are not capable of representing these graphs,
+ as they can only generate directed trees.
+
+* This package allows installing order-dependent load callbacks and then
+ resolves that graph at load time, with cycle detection. Similarly, there is
+ no analogous feature possible in the standard encoders.
+
+* This package handles the resolution of interfaces, based on a registered
+ type name. For interface objects type information is saved in the serialized
+ format. This is generally true for `gob` as well, but it works differently.
+
+Here's an overview of how encoding and decoding works.
+
+## Encoding
+
+Encoding produces a `statefile`, which contains a list of chunks of the form
+`(header, payload)`. The payload can either be some raw data, or a series of
+encoded wire objects representing some object graph. All encoded objects are
+defined in the `wire` subpackage.
+
+Encoding of an object graph begins with `encodeState.Save`.
+
+### 1. Memory Map & Encoding
+
+To discover relationships between potentially interdependent data structures
+(for example, a struct may contain pointers to members of other data
+structures), the encoder first walks the object graph and constructs a memory
+map of the objects in the input graph. As this walk progresses, objects are
+queued in the `pending` list and items are placed on the `deferred` list as they
+are discovered. No single object will be encoded multiple times, but the
+discovered relationships between objects may change as more parts of the overall
+object graph are discovered.
+
+The encoder starts at the root object and recursively visits all reachable
+objects, recording the address ranges containing the underlying data for each
+object. This is stored as a segment set (`addrSet`), mapping address ranges to
+the of the object occupying the range; see `encodeState.values`. Note that there
+is special handling for zero-sized types and map objects during this process.
+
+Additionally, the encoder assigns each object a unique identifier which is used
+to indicate relationships between objects in the statefile; see `objectID` in
+`encode.go`.
+
+### 2. Type Serialization
+
+The enoder will subsequently serialize all information about discovered types,
+including field names. These are used during decoding to reconcile these types
+with other internally registered types.
+
+### 3. Object Serialization
+
+With a full address map, and all objects correctly encoded, all object encodings
+are serialized. The assigned `objectID`s aren't explicitly encoded in the
+statefile. The order of object messages in the stream determine their IDs.
+
+### Example
+
+Given the following data structure definitions:
+
+```go
+type system struct {
+ o *outer
+ i *inner
+}
+
+type outer struct {
+ a int64
+ cn *container
+}
+
+type container struct {
+ n uint64
+ elem *inner
+}
+
+type inner struct {
+ c container
+ x, y uint64
+}
+```
+
+Initialized like this:
+
+```go
+o := outer{
+ a: 10,
+ cn: nil,
+}
+i := inner{
+ x: 20,
+ y: 30,
+ c: container{},
+}
+s := system{
+ o: &o,
+ i: &i,
+}
+
+o.cn = &i.c
+o.cn.elem = &i
+
+```
+
+Encoding will produce an object stream like this:
+
+```
+g0r1 = struct{
+ i: g0r3,
+ o: g0r2,
+}
+g0r2 = struct{
+ a: 10,
+ cn: g0r3.c,
+}
+g0r3 = struct{
+ c: struct{
+ elem: g0r3,
+ n: 0u,
+ },
+ x: 20u,
+ y: 30u,
+}
+```
+
+Note how `g0r3.c` is correctly encoded as the underlying `container` object for
+`inner.c`, and how the pointer from `outer.cn` points to it, despite `system.i`
+being discovered after the pointer to it in `system.o.cn`. Also note that
+decoding isn't strictly reliant on the order of encoded object stream, as long
+as the relationship between objects are correctly encoded.
+
+## Decoding
+
+Decoding reads the statefile and reconstructs the object graph. Decoding begins
+in `decodeState.Load`. Decoding is performed in a single pass over the object
+stream in the statefile, and a subsequent pass over all deserialized objects is
+done to fire off all loading callbacks in the correctly defined order. Note that
+introducing cycles is possible here, but these are detected and an error will be
+returned.
+
+Decoding is relatively straight forward. For most primitive values, the decoder
+constructs an appropriate object and fills it with the values encoded in the
+statefile. Pointers need special handling, as they must point to a value
+allocated elsewhere. When values are constructed, the decoder indexes them by
+their `objectID`s in `decodeState.objectsByID`. The target of pointers are
+resolved by searching for the target in this index by their `objectID`; see
+`decodeState.register`. For pointers to values inside another value (fields in a
+pointer, elements of an array), the decoder uses the accessor path to walk to
+the appropriate location; see `walkChild`.
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
new file mode 100644
index 000000000..c9971cdf6
--- /dev/null
+++ b/pkg/state/decode.go
@@ -0,0 +1,725 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "math"
+ "reflect"
+
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// internalCallback is a interface called on object completion.
+//
+// There are two implementations: objectDecodeState & userCallback.
+type internalCallback interface {
+ // source returns the dependent object. May be nil.
+ source() *objectDecodeState
+
+ // callbackRun executes the callback.
+ callbackRun()
+}
+
+// userCallback is an implementation of internalCallback.
+type userCallback func()
+
+// source implements internalCallback.source.
+func (userCallback) source() *objectDecodeState {
+ return nil
+}
+
+// callbackRun implements internalCallback.callbackRun.
+func (uc userCallback) callbackRun() {
+ uc()
+}
+
+// objectDecodeState represents an object that may be in the process of being
+// decoded. Specifically, it represents either a decoded object, or an an
+// interest in a future object that will be decoded. When that interest is
+// registered (via register), the storage for the object will be created, but
+// it will not be decoded until the object is encountered in the stream.
+type objectDecodeState struct {
+ // id is the id for this object.
+ id objectID
+
+ // typ is the id for this typeID. This may be zero if this is not a
+ // type-registered structure.
+ typ typeID
+
+ // obj is the object. This may or may not be valid yet, depending on
+ // whether complete returns true. However, regardless of whether the
+ // object is valid, obj contains a final storage location for the
+ // object. This is immutable.
+ //
+ // Note that this must be addressable (obj.Addr() must not panic).
+ //
+ // The obj passed to the decode methods below will equal this obj only
+ // in the case of decoding the top-level object. However, the passed
+ // obj may represent individual fields, elements of a slice, etc. that
+ // are effectively embedded within the reflect.Value below but with
+ // distinct types.
+ obj reflect.Value
+
+ // blockedBy is the number of dependencies this object has.
+ blockedBy int
+
+ // callbacksInline is inline storage for callbacks.
+ callbacksInline [2]internalCallback
+
+ // callbacks is a set of callbacks to execute on load.
+ callbacks []internalCallback
+
+ completeEntry
+}
+
+// addCallback adds a callback to the objectDecodeState.
+func (ods *objectDecodeState) addCallback(ic internalCallback) {
+ if ods.callbacks == nil {
+ ods.callbacks = ods.callbacksInline[:0]
+ }
+ ods.callbacks = append(ods.callbacks, ic)
+}
+
+// findCycleFor returns when the given object is found in the blocking set.
+func (ods *objectDecodeState) findCycleFor(target *objectDecodeState) []*objectDecodeState {
+ for _, ic := range ods.callbacks {
+ other := ic.source()
+ if other != nil && other == target {
+ return []*objectDecodeState{target}
+ } else if childList := other.findCycleFor(target); childList != nil {
+ return append(childList, other)
+ }
+ }
+
+ // This should not occur.
+ Failf("no deadlock found?")
+ panic("unreachable")
+}
+
+// findCycle finds a dependency cycle.
+func (ods *objectDecodeState) findCycle() []*objectDecodeState {
+ return append(ods.findCycleFor(ods), ods)
+}
+
+// source implements internalCallback.source.
+func (ods *objectDecodeState) source() *objectDecodeState {
+ return ods
+}
+
+// callbackRun implements internalCallback.callbackRun.
+func (ods *objectDecodeState) callbackRun() {
+ ods.blockedBy--
+}
+
+// decodeState is a graph of objects in the process of being decoded.
+//
+// The decode process involves loading the breadth-first graph generated by
+// encode. This graph is read in it's entirety, ensuring that all object
+// storage is complete.
+//
+// As the graph is being serialized, a set of completion callbacks are
+// executed. These completion callbacks should form a set of acyclic subgraphs
+// over the original one. After decoding is complete, the objects are scanned
+// to ensure that all callbacks are executed, otherwise the callback graph was
+// not acyclic.
+type decodeState struct {
+ // ctx is the decode context.
+ ctx context.Context
+
+ // r is the input stream.
+ r wire.Reader
+
+ // types is the type database.
+ types typeDecodeDatabase
+
+ // objectByID is the set of objects in progress.
+ objectsByID []*objectDecodeState
+
+ // deferred are objects that have been read, by no interest has been
+ // registered yet. These will be decoded once interest in registered.
+ deferred map[objectID]wire.Object
+
+ // pending is the set of objects that are not yet complete.
+ pending completeList
+
+ // stats tracks time data.
+ stats Stats
+}
+
+// lookup looks up an object in decodeState or returns nil if no such object
+// has been previously registered.
+func (ds *decodeState) lookup(id objectID) *objectDecodeState {
+ if len(ds.objectsByID) < int(id) {
+ return nil
+ }
+ return ds.objectsByID[id-1]
+}
+
+// checkComplete checks for completion.
+func (ds *decodeState) checkComplete(ods *objectDecodeState) bool {
+ // Still blocked?
+ if ods.blockedBy > 0 {
+ return false
+ }
+
+ // Track stats if relevant.
+ if ods.callbacks != nil && ods.typ != 0 {
+ ds.stats.start(ods.typ)
+ defer ds.stats.done()
+ }
+
+ // Fire all callbacks.
+ for _, ic := range ods.callbacks {
+ ic.callbackRun()
+ }
+
+ // Mark completed.
+ cbs := ods.callbacks
+ ods.callbacks = nil
+ ds.pending.Remove(ods)
+
+ // Recursively check others.
+ for _, ic := range cbs {
+ if other := ic.source(); other != nil && other.blockedBy == 0 {
+ ds.checkComplete(other)
+ }
+ }
+
+ return true // All set.
+}
+
+// wait registers a dependency on an object.
+//
+// As a special case, we always allow _useable_ references back to the first
+// decoding object because it may have fields that are already decoded. We also
+// allow trivial self reference, since they can be handled internally.
+func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback func()) {
+ switch id {
+ case waiter.id:
+ // Trivial self reference.
+ fallthrough
+ case 1:
+ // Root object; see above.
+ if callback != nil {
+ callback()
+ }
+ return
+ }
+
+ // Mark as blocked.
+ waiter.blockedBy++
+
+ // No nil can be returned here.
+ other := ds.lookup(id)
+ if callback != nil {
+ // Add the additional user callback.
+ other.addCallback(userCallback(callback))
+ }
+
+ // Mark waiter as unblocked.
+ other.addCallback(waiter)
+}
+
+// waitObject notes a blocking relationship.
+func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, callback func()) {
+ if rv, ok := encoded.(*wire.Ref); ok && rv.Root != 0 {
+ // Refs can encode pointers and maps.
+ ds.wait(ods, objectID(rv.Root), callback)
+ } else if sv, ok := encoded.(*wire.Slice); ok && sv.Ref.Root != 0 {
+ // See decodeObject; we need to wait for the array (if non-nil).
+ ds.wait(ods, objectID(sv.Ref.Root), callback)
+ } else if iv, ok := encoded.(*wire.Interface); ok {
+ // It's an interface (wait recurisvely).
+ ds.waitObject(ods, iv.Value, callback)
+ } else if callback != nil {
+ // Nothing to wait for: execute the callback immediately.
+ callback()
+ }
+}
+
+// walkChild returns a child object from obj, given an accessor path. This is
+// the decode-side equivalent to traverse in encode.go.
+//
+// For the purposes of this function, a child object is either a field within a
+// struct or an array element, with one such indirection per element in
+// path. The returned value may be an unexported field, so it may not be
+// directly assignable. See unsafePointerTo.
+func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
+ // See wire.Ref.Dots. The path here is specified in reverse order.
+ for i := len(path) - 1; i >= 0; i-- {
+ switch pc := path[i].(type) {
+ case *wire.FieldName: // Must be a pointer.
+ if obj.Kind() != reflect.Struct {
+ Failf("next component in child path is a field name, but the current object is not a struct. Path: %v, current obj: %#v", path, obj)
+ }
+ obj = obj.FieldByName(string(*pc))
+ case wire.Index: // Embedded.
+ if obj.Kind() != reflect.Array {
+ Failf("next component in child path is an array index, but the current object is not an array. Path: %v, current obj: %#v", path, obj)
+ }
+ obj = obj.Index(int(pc))
+ default:
+ panic("unreachable: switch should be exhaustive")
+ }
+ }
+ return obj
+}
+
+// register registers a decode with a type.
+//
+// This type is only used to instantiate a new object if it has not been
+// registered previously. This depends on the type provided if none is
+// available in the object itself.
+func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value {
+ // Grow the objectsByID slice.
+ id := objectID(r.Root)
+ if len(ds.objectsByID) < int(id) {
+ ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...)
+ }
+
+ // Does this object already exist?
+ ods := ds.objectsByID[id-1]
+ if ods != nil {
+ return walkChild(r.Dots, ods.obj)
+ }
+
+ // Create the object.
+ if len(r.Dots) != 0 {
+ typ = ds.findType(r.Type)
+ }
+ v := reflect.New(typ)
+ ods = &objectDecodeState{
+ id: id,
+ obj: v.Elem(),
+ }
+ ds.objectsByID[id-1] = ods
+ ds.pending.PushBack(ods)
+
+ // Process any deferred objects & callbacks.
+ if encoded, ok := ds.deferred[id]; ok {
+ delete(ds.deferred, id)
+ ds.decodeObject(ods, ods.obj, encoded)
+ }
+
+ return walkChild(r.Dots, ods.obj)
+}
+
+// objectDecoder is for decoding structs.
+type objectDecoder struct {
+ // ds is decodeState.
+ ds *decodeState
+
+ // ods is current object being decoded.
+ ods *objectDecodeState
+
+ // reconciledTypeEntry is the reconciled type information.
+ rte *reconciledTypeEntry
+
+ // encoded is the encoded object state.
+ encoded *wire.Struct
+}
+
+// load is helper for the public methods on Source.
+func (od *objectDecoder) load(slot int, objPtr reflect.Value, wait bool, fn func()) {
+ // Note that we have reconciled the type and may remap the fields here
+ // to match what's expected by the decoder. The "slot" parameter here
+ // is in terms of the local type, where the fields in the encoded
+ // object are in terms of the wire object's type, which might be in a
+ // different order (but will have the same fields).
+ v := *od.encoded.Field(od.rte.FieldOrder[slot])
+ od.ds.decodeObject(od.ods, objPtr.Elem(), v)
+ if wait {
+ // Mark this individual object a blocker.
+ od.ds.waitObject(od.ods, v, fn)
+ }
+}
+
+// aterLoad implements Source.AfterLoad.
+func (od *objectDecoder) afterLoad(fn func()) {
+ // Queue the local callback; this will execute when all of the above
+ // data dependencies have been cleared.
+ od.ods.addCallback(userCallback(fn))
+}
+
+// decodeStruct decodes a struct value.
+func (ds *decodeState) decodeStruct(ods *objectDecodeState, obj reflect.Value, encoded *wire.Struct) {
+ if encoded.TypeID == 0 {
+ // Allow anonymous empty structs, but only if the encoded
+ // object also has no fields.
+ if encoded.Fields() == 0 && obj.NumField() == 0 {
+ return
+ }
+
+ // Propagate an error.
+ Failf("empty struct on wire %#v has field mismatch with type %q", encoded, obj.Type().Name())
+ }
+
+ // Lookup the object type.
+ rte := ds.types.Lookup(typeID(encoded.TypeID), obj.Type())
+ ods.typ = typeID(encoded.TypeID)
+
+ // Invoke the loader.
+ od := objectDecoder{
+ ds: ds,
+ ods: ods,
+ rte: rte,
+ encoded: encoded,
+ }
+ ds.stats.start(ods.typ)
+ defer ds.stats.done()
+ if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
+ // Note: may be a registered empty struct which does not
+ // implement the saver/loader interfaces.
+ sl.StateLoad(Source{internal: od})
+ }
+}
+
+// decodeMap decodes a map value.
+func (ds *decodeState) decodeMap(ods *objectDecodeState, obj reflect.Value, encoded *wire.Map) {
+ if obj.IsNil() {
+ // See pointerTo.
+ obj.Set(reflect.MakeMap(obj.Type()))
+ }
+ for i := 0; i < len(encoded.Keys); i++ {
+ // Decode the objects.
+ kv := reflect.New(obj.Type().Key()).Elem()
+ vv := reflect.New(obj.Type().Elem()).Elem()
+ ds.decodeObject(ods, kv, encoded.Keys[i])
+ ds.decodeObject(ods, vv, encoded.Values[i])
+ ds.waitObject(ods, encoded.Keys[i], nil)
+ ds.waitObject(ods, encoded.Values[i], nil)
+
+ // Set in the map.
+ obj.SetMapIndex(kv, vv)
+ }
+}
+
+// decodeArray decodes an array value.
+func (ds *decodeState) decodeArray(ods *objectDecodeState, obj reflect.Value, encoded *wire.Array) {
+ if len(encoded.Contents) != obj.Len() {
+ Failf("mismatching array length expect=%d, actual=%d", obj.Len(), len(encoded.Contents))
+ }
+ // Decode the contents into the array.
+ for i := 0; i < len(encoded.Contents); i++ {
+ ds.decodeObject(ods, obj.Index(i), encoded.Contents[i])
+ ds.waitObject(ods, encoded.Contents[i], nil)
+ }
+}
+
+// findType finds the type for the given wire.TypeSpecs.
+func (ds *decodeState) findType(t wire.TypeSpec) reflect.Type {
+ switch x := t.(type) {
+ case wire.TypeID:
+ typ := ds.types.LookupType(typeID(x))
+ rte := ds.types.Lookup(typeID(x), typ)
+ return rte.LocalType
+ case *wire.TypeSpecPointer:
+ return reflect.PtrTo(ds.findType(x.Type))
+ case *wire.TypeSpecArray:
+ return reflect.ArrayOf(int(x.Count), ds.findType(x.Type))
+ case *wire.TypeSpecSlice:
+ return reflect.SliceOf(ds.findType(x.Type))
+ case *wire.TypeSpecMap:
+ return reflect.MapOf(ds.findType(x.Key), ds.findType(x.Value))
+ default:
+ // Should not happen.
+ Failf("unknown type %#v", t)
+ }
+ panic("unreachable")
+}
+
+// decodeInterface decodes an interface value.
+func (ds *decodeState) decodeInterface(ods *objectDecodeState, obj reflect.Value, encoded *wire.Interface) {
+ if _, ok := encoded.Type.(wire.TypeSpecNil); ok {
+ // Special case; the nil object. Just decode directly, which
+ // will read nil from the wire (if encoded correctly).
+ ds.decodeObject(ods, obj, encoded.Value)
+ return
+ }
+
+ // We now need to resolve the actual type.
+ typ := ds.findType(encoded.Type)
+
+ // We need to imbue type information here, then we can proceed to
+ // decode normally. In order to avoid issues with setting value-types,
+ // we create a new non-interface version of this object. We will then
+ // set the interface object to be equal to whatever we decode.
+ origObj := obj
+ obj = reflect.New(typ).Elem()
+ defer origObj.Set(obj)
+
+ // With the object now having sufficient type information to actually
+ // have Set called on it, we can proceed to decode the value.
+ ds.decodeObject(ods, obj, encoded.Value)
+}
+
+// isFloatEq determines if x and y represent the same value.
+func isFloatEq(x float64, y float64) bool {
+ switch {
+ case math.IsNaN(x):
+ return math.IsNaN(y)
+ case math.IsInf(x, 1):
+ return math.IsInf(y, 1)
+ case math.IsInf(x, -1):
+ return math.IsInf(y, -1)
+ default:
+ return x == y
+ }
+}
+
+// isComplexEq determines if x and y represent the same value.
+func isComplexEq(x complex128, y complex128) bool {
+ return isFloatEq(real(x), real(y)) && isFloatEq(imag(x), imag(y))
+}
+
+// decodeObject decodes a object value.
+func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, encoded wire.Object) {
+ switch x := encoded.(type) {
+ case wire.Nil: // Fast path: first.
+ // We leave obj alone here. That's because if obj represents an
+ // interface, it may have been imbued with type information in
+ // decodeInterface, and we don't want to destroy that.
+ case *wire.Ref:
+ // Nil pointers may be encoded in a "forceValue" context. For
+ // those we just leave it alone as the value will already be
+ // correct (nil).
+ if id := objectID(x.Root); id == 0 {
+ return
+ }
+
+ // Note that if this is a map type, we go through a level of
+ // indirection to allow for map aliasing.
+ if obj.Kind() == reflect.Map {
+ v := ds.register(x, obj.Type())
+ if v.IsNil() {
+ // Note that we don't want to clobber the map
+ // if has already been decoded by decodeMap. We
+ // just make it so that we have a consistent
+ // reference when that eventually does happen.
+ v.Set(reflect.MakeMap(v.Type()))
+ }
+ obj.Set(v)
+ return
+ }
+
+ // Normal assignment: authoritative only if no dots.
+ v := ds.register(x, obj.Type().Elem())
+ if v.IsValid() {
+ obj.Set(unsafePointerTo(v))
+ }
+ case wire.Bool:
+ obj.SetBool(bool(x))
+ case wire.Int:
+ obj.SetInt(int64(x))
+ if obj.Int() != int64(x) {
+ Failf("signed integer truncated from %v to %v", int64(x), obj.Int())
+ }
+ case wire.Uint:
+ obj.SetUint(uint64(x))
+ if obj.Uint() != uint64(x) {
+ Failf("unsigned integer truncated from %v to %v", uint64(x), obj.Uint())
+ }
+ case wire.Float32:
+ obj.SetFloat(float64(x))
+ case wire.Float64:
+ obj.SetFloat(float64(x))
+ if !isFloatEq(obj.Float(), float64(x)) {
+ Failf("floating point number truncated from %v to %v", float64(x), obj.Float())
+ }
+ case *wire.Complex64:
+ obj.SetComplex(complex128(*x))
+ case *wire.Complex128:
+ obj.SetComplex(complex128(*x))
+ if !isComplexEq(obj.Complex(), complex128(*x)) {
+ Failf("complex number truncated from %v to %v", complex128(*x), obj.Complex())
+ }
+ case *wire.String:
+ obj.SetString(string(*x))
+ case *wire.Slice:
+ // See *wire.Ref above; same applies.
+ if id := objectID(x.Ref.Root); id == 0 {
+ return
+ }
+ // Note that it's fine to slice the array here and assume that
+ // contents will still be filled in later on.
+ typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type.
+ v := ds.register(&x.Ref, typ)
+ obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity)))
+ case *wire.Array:
+ ds.decodeArray(ods, obj, x)
+ case *wire.Struct:
+ ds.decodeStruct(ods, obj, x)
+ case *wire.Map:
+ ds.decodeMap(ods, obj, x)
+ case *wire.Interface:
+ ds.decodeInterface(ods, obj, x)
+ default:
+ // Shoud not happen, not propagated as an error.
+ Failf("unknown object %#v for %q", encoded, obj.Type().Name())
+ }
+}
+
+// Load deserializes the object graph rooted at obj.
+//
+// This function may panic and should be run in safely().
+func (ds *decodeState) Load(obj reflect.Value) {
+ ds.stats.init()
+ defer ds.stats.fini(func(id typeID) string {
+ return ds.types.LookupName(id)
+ })
+
+ // Create the root object.
+ ds.objectsByID = append(ds.objectsByID, &objectDecodeState{
+ id: 1,
+ obj: obj,
+ })
+
+ // Read the number of objects.
+ lastID, object, err := ReadHeader(ds.r)
+ if err != nil {
+ Failf("header error: %w", err)
+ }
+ if !object {
+ Failf("object missing")
+ }
+
+ // Decode all objects.
+ var (
+ encoded wire.Object
+ ods *objectDecodeState
+ id = objectID(1)
+ tid = typeID(1)
+ )
+ if err := safely(func() {
+ // Decode all objects in the stream.
+ //
+ // Note that the structure of this decoding loop should match
+ // the raw decoding loop in printer.go.
+ for id <= objectID(lastID) {
+ // Unmarshal the object.
+ encoded = wire.Load(ds.r)
+
+ // Is this a type object? Handle inline.
+ if wt, ok := encoded.(*wire.Type); ok {
+ ds.types.Register(wt)
+ tid++
+ encoded = nil
+ continue
+ }
+
+ // Actually resolve the object.
+ ods = ds.lookup(id)
+ if ods != nil {
+ // Decode the object.
+ ds.decodeObject(ods, ods.obj, encoded)
+ } else {
+ // If an object hasn't had interest registered
+ // previously or isn't yet valid, we deferred
+ // decoding until interest is registered.
+ ds.deferred[id] = encoded
+ }
+
+ // For error handling.
+ ods = nil
+ encoded = nil
+ id++
+ }
+ }); err != nil {
+ // Include as much information as we can, taking into account
+ // the possible state transitions above.
+ if ods != nil {
+ Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err)
+ } else if encoded != nil {
+ Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err)
+ } else {
+ Failf("general decoding error: %w", err)
+ }
+ }
+
+ // Check if we have any deferred objects.
+ for id, encoded := range ds.deferred {
+ // Shoud never happen, the graph was bogus.
+ Failf("still have deferred objects: one is ID %d, %#v", id, encoded)
+ }
+
+ // Scan and fire all callbacks. We iterate over the list of incomplete
+ // objects until all have been finished. We stop iterating if no
+ // objects become complete (there is a dependency cycle).
+ //
+ // Note that we iterate backwards here, because there will be a strong
+ // tendendcy for blocking relationships to go from earlier objects to
+ // later (deeper) objects in the graph. This will reduce the number of
+ // iterations required to finish all objects.
+ if err := safely(func() {
+ for ds.pending.Back() != nil {
+ thisCycle := false
+ for ods = ds.pending.Back(); ods != nil; {
+ if ds.checkComplete(ods) {
+ thisCycle = true
+ break
+ }
+ ods = ods.Prev()
+ }
+ if !thisCycle {
+ break
+ }
+ }
+ }); err != nil {
+ Failf("error executing callbacks for %#v: %w", ods.obj.Interface(), err)
+ }
+
+ // Check if we have any remaining dependency cycles. If there are any
+ // objects left in the pending list, then it must be due to a cycle.
+ if ods := ds.pending.Front(); ods != nil {
+ // This must be the result of a dependency cycle.
+ cycle := ods.findCycle()
+ var buf bytes.Buffer
+ buf.WriteString("dependency cycle: {")
+ for i, cycleOS := range cycle {
+ if i > 0 {
+ buf.WriteString(" => ")
+ }
+ fmt.Fprintf(&buf, "%q", cycleOS.obj.Type())
+ }
+ buf.WriteString("}")
+ Failf("incomplete graph: %s", string(buf.Bytes()))
+ }
+}
+
+// ReadHeader reads an object header.
+//
+// Each object written to the statefile is prefixed with a header. See
+// WriteHeader for more information; these functions are exported to allow
+// non-state writes to the file to play nice with debugging tools.
+func ReadHeader(r wire.Reader) (length uint64, object bool, err error) {
+ // Read the header.
+ err = safely(func() {
+ length = wire.LoadUint(r)
+ })
+ if err != nil {
+ // On the header, pass raw I/O errors.
+ if sErr, ok := err.(*ErrState); ok {
+ return 0, false, sErr.Unwrap()
+ }
+ }
+
+ // Decode whether the object is valid.
+ object = length&objectFlag != 0
+ length &^= objectFlag
+ return
+}
diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go
new file mode 100644
index 000000000..d048f61a1
--- /dev/null
+++ b/pkg/state/decode_unsafe.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on
+// values representing unexported fields. This bypasses visibility, but not
+// type safety.
+func unsafePointerTo(obj reflect.Value) reflect.Value {
+ return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr()))
+}
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
new file mode 100644
index 000000000..92fcad4e9
--- /dev/null
+++ b/pkg/state/encode.go
@@ -0,0 +1,841 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "context"
+ "reflect"
+
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// objectEncodeState the type and identity of an object occupying a memory
+// address range. This is the value type for addrSet, and the intrusive entry
+// for the pending and deferred lists.
+type objectEncodeState struct {
+ // id is the assigned ID for this object.
+ id objectID
+
+ // obj is the object value. Note that this may be replaced if we
+ // encounter an object that contains this object. When this happens (in
+ // resolve), we will update existing references approprately, below,
+ // and defer a re-encoding of the object.
+ obj reflect.Value
+
+ // encoded is the encoded value of this object. Note that this may not
+ // be up to date if this object is still in the deferred list.
+ encoded wire.Object
+
+ // how indicates whether this object should be encoded as a value. This
+ // is used only for deferred encoding.
+ how encodeStrategy
+
+ // refs are the list of reference objects used by other objects
+ // referring to this object. When the object is updated, these
+ // references may be updated directly and automatically.
+ refs []*wire.Ref
+
+ pendingEntry
+ deferredEntry
+}
+
+// encodeState is state used for encoding.
+//
+// The encoding process constructs a representation of the in-memory graph of
+// objects before a single object is serialized. This is done to ensure that
+// all references can be fully disambiguated. See resolve for more details.
+type encodeState struct {
+ // ctx is the encode context.
+ ctx context.Context
+
+ // w is the output stream.
+ w wire.Writer
+
+ // types is the type database.
+ types typeEncodeDatabase
+
+ // lastID is the last allocated object ID.
+ lastID objectID
+
+ // values tracks the address ranges occupied by objects, along with the
+ // types of these objects. This is used to locate pointer targets,
+ // including pointers to fields within another type.
+ //
+ // Multiple objects may overlap in memory iff the larger object fully
+ // contains the smaller one, and the type of the smaller object matches
+ // a field or array element's type at the appropriate offset. An
+ // arbitrary number of objects may be nested in this manner.
+ //
+ // Note that this does not track zero-sized objects, those are tracked
+ // by zeroValues below.
+ values addrSet
+
+ // zeroValues tracks zero-sized objects.
+ zeroValues map[reflect.Type]*objectEncodeState
+
+ // deferred is the list of objects to be encoded.
+ deferred deferredList
+
+ // pendingTypes is the list of types to be serialized. Serialization
+ // will occur when all objects have been encoded, but before pending is
+ // serialized.
+ pendingTypes []wire.Type
+
+ // pending is the list of objects to be serialized. Serialization does
+ // not actually occur until the full object graph is computed.
+ pending pendingList
+
+ // stats tracks time data.
+ stats Stats
+}
+
+// isSameSizeParent returns true if child is a field value or element within
+// parent. Only a struct or array can have a child value.
+//
+// isSameSizeParent deals with objects like this:
+//
+// struct child {
+// // fields..
+// }
+//
+// struct parent {
+// c child
+// }
+//
+// var p parent
+// record(&p.c)
+//
+// Here, &p and &p.c occupy the exact same address range.
+//
+// Or like this:
+//
+// struct child {
+// // fields
+// }
+//
+// var arr [1]parent
+// record(&arr[0])
+//
+// Similarly, &arr[0] and &arr[0].c have the exact same address range.
+//
+// Precondition: parent and child must occupy the same memory.
+func isSameSizeParent(parent reflect.Value, childType reflect.Type) bool {
+ switch parent.Kind() {
+ case reflect.Struct:
+ for i := 0; i < parent.NumField(); i++ {
+ field := parent.Field(i)
+ if field.Type() == childType {
+ return true
+ }
+ // Recurse through any intermediate types.
+ if isSameSizeParent(field, childType) {
+ return true
+ }
+ // Does it make sense to keep going if the first field
+ // doesn't match? Yes, because there might be an
+ // arbitrary number of zero-sized fields before we get
+ // a match, and childType itself can be zero-sized.
+ }
+ return false
+ case reflect.Array:
+ // The only case where an array with more than one elements can
+ // return true is if childType is zero-sized. In such cases,
+ // it's ambiguous which element contains the match since a
+ // zero-sized child object fully fits in any of the zero-sized
+ // elements in an array... However since all elements are of
+ // the same type, we only need to check one element.
+ //
+ // For non-zero-sized childTypes, parent.Len() must be 1, but a
+ // combination of the precondition and an implicit comparison
+ // between the array element size and childType ensures this.
+ return parent.Len() > 0 && isSameSizeParent(parent.Index(0), childType)
+ default:
+ return false
+ }
+}
+
+// nextID returns the next valid ID.
+func (es *encodeState) nextID() objectID {
+ es.lastID++
+ return objectID(es.lastID)
+}
+
+// dummyAddr points to the dummy zero-sized address.
+var dummyAddr = reflect.ValueOf(new(struct{})).Pointer()
+
+// resolve records the address range occupied by an object.
+func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
+ addr := obj.Pointer()
+
+ // Is this a map pointer? Just record the single address. It is not
+ // possible to take any pointers into the map internals.
+ if obj.Kind() == reflect.Map {
+ if addr == 0 {
+ // Just leave the nil reference alone. This is fine, we
+ // may need to encode as a reference in this way. We
+ // return nil for our objectEncodeState so that anyone
+ // depending on this value knows there's nothing there.
+ return
+ }
+ if seg, _ := es.values.Find(addr); seg.Ok() {
+ // Ensure the map types match.
+ existing := seg.Value()
+ if existing.obj.Type() != obj.Type() {
+ Failf("overlapping map objects at 0x%x: [new object] %#v [existing object type] %s", addr, obj, existing.obj)
+ }
+
+ // No sense recording refs, maps may not be replaced by
+ // covering objects, they are maximal.
+ ref.Root = wire.Uint(existing.id)
+ return
+ }
+
+ // Record the map.
+ oes := &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ how: encodeMapAsValue,
+ }
+ es.values.Add(addrRange{addr, addr + 1}, oes)
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+
+ // See above: no ref recording.
+ ref.Root = wire.Uint(oes.id)
+ return
+ }
+
+ // If not a map, then the object must be a pointer.
+ if obj.Kind() != reflect.Ptr {
+ Failf("attempt to record non-map and non-pointer object %#v", obj)
+ }
+
+ obj = obj.Elem() // Value from here.
+
+ // Is this a zero-sized type?
+ typ := obj.Type()
+ size := typ.Size()
+ if size == 0 {
+ if addr == dummyAddr {
+ // Zero-sized objects point to a dummy byte within the
+ // runtime. There's no sense recording this in the
+ // address map. We add this to the dedicated
+ // zeroValues.
+ //
+ // Note that zero-sized objects must be *true*
+ // zero-sized objects. They cannot be part of some
+ // larger object. In that case, they are assigned a
+ // 1-byte address at the end of the object.
+ oes, ok := es.zeroValues[typ]
+ if !ok {
+ oes = &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ }
+ es.zeroValues[typ] = oes
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+ }
+
+ // There's also no sense tracking back references. We
+ // know that this is a true zero-sized object, and not
+ // part of a larger container, so it will not change.
+ ref.Root = wire.Uint(oes.id)
+ return
+ }
+ size = 1 // See above.
+ }
+
+ // Calculate the container.
+ end := addr + size
+ r := addrRange{addr, end}
+ if seg, _ := es.values.Find(addr); seg.Ok() {
+ existing := seg.Value()
+ switch {
+ case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type():
+ // The object is a perfect match. Happy path. Avoid the
+ // traversal and just return directly. We don't need to
+ // encode the type information or any dots here.
+ ref.Root = wire.Uint(existing.id)
+ existing.refs = append(existing.refs, ref)
+ return
+
+ case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end):
+ // The previously registered object is larger than
+ // this, no need to update. But we expect some
+ // traversal below.
+
+ case seg.Start() == addr && seg.End() == end:
+ if !isSameSizeParent(obj, existing.obj.Type()) {
+ break // Needs traversal.
+ }
+ fallthrough // Needs update.
+
+ case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end):
+ // Update the object and redo the encoding.
+ old := existing.obj
+ existing.obj = obj
+ es.deferred.Remove(existing)
+ es.deferred.PushBack(existing)
+
+ // The previously registered object is superseded by
+ // this new object. We are guaranteed to not have any
+ // mergeable neighbours in this segment set.
+ if !raceEnabled {
+ seg.SetRangeUnchecked(r)
+ } else {
+ // Add extra paranoid. This will be statically
+ // removed at compile time unless a race build.
+ es.values.Remove(seg)
+ es.values.Add(r, existing)
+ seg = es.values.LowerBoundSegment(addr)
+ }
+
+ // Compute the traversal required & update references.
+ dots := traverse(obj.Type(), old.Type(), addr, seg.Start())
+ wt := es.findType(obj.Type())
+ for _, ref := range existing.refs {
+ ref.Dots = append(ref.Dots, dots...)
+ ref.Type = wt
+ }
+ default:
+ // There is a non-sensical overlap.
+ Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj)
+ }
+
+ // Compute the new reference, record and return it.
+ ref.Root = wire.Uint(existing.id)
+ ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr)
+ ref.Type = es.findType(obj.Type())
+ existing.refs = append(existing.refs, ref)
+ return
+ }
+
+ // The only remaining case is a pointer value that doesn't overlap with
+ // any registered addresses. Create a new entry for it, and start
+ // tracking the first reference we just created.
+ oes := &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ }
+ if !raceEnabled {
+ es.values.AddWithoutMerging(r, oes)
+ } else {
+ // Merges should never happen. This is just enabled extra
+ // sanity checks because the Merge function below will panic.
+ es.values.Add(r, oes)
+ }
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+ ref.Root = wire.Uint(oes.id)
+ oes.refs = append(oes.refs, ref)
+}
+
+// traverse searches for a target object within a root object, where the target
+// object is a struct field or array element within root, with potentially
+// multiple intervening types. traverse returns the set of field or element
+// traversals required to reach the target.
+//
+// Note that for efficiency, traverse returns the dots in the reverse order.
+// That is, the first traversal required will be the last element of the list.
+//
+// Precondition: The target object must lie completely within the range defined
+// by [rootAddr, rootAddr + sizeof(rootType)].
+func traverse(rootType, targetType reflect.Type, rootAddr, targetAddr uintptr) []wire.Dot {
+ // Recursion base case: the types actually match.
+ if targetType == rootType && targetAddr == rootAddr {
+ return nil
+ }
+
+ switch rootType.Kind() {
+ case reflect.Struct:
+ offset := targetAddr - rootAddr
+ for i := rootType.NumField(); i > 0; i-- {
+ field := rootType.Field(i - 1)
+ // The first field from the end with an offset that is
+ // smaller than or equal to our address offset is where
+ // the target is located. Traverse from there.
+ if field.Offset <= offset {
+ dots := traverse(field.Type, targetType, rootAddr+field.Offset, targetAddr)
+ fieldName := wire.FieldName(field.Name)
+ return append(dots, &fieldName)
+ }
+ }
+ // Should never happen; the target should be reachable.
+ Failf("no field in root type %v contains target type %v", rootType, targetType)
+
+ case reflect.Array:
+ // Since arrays have homogenous types, all elements have the
+ // same size and we can compute where the target lives. This
+ // does not matter for the purpose of typing, but matters for
+ // the purpose of computing the address of the given index.
+ elemSize := int(rootType.Elem().Size())
+ n := int(targetAddr-rootAddr) / elemSize // Relies on integer division rounding down.
+ if rootType.Len() < n {
+ Failf("traversal target of type %v @%x is beyond the end of the array type %v @%x with %v elements",
+ targetType, targetAddr, rootType, rootAddr, rootType.Len())
+ }
+ dots := traverse(rootType.Elem(), targetType, rootAddr+uintptr(n*elemSize), targetAddr)
+ return append(dots, wire.Index(n))
+
+ default:
+ // For any other type, there's no possibility of aliasing so if
+ // the types didn't match earlier then we have an addresss
+ // collision which shouldn't be possible at this point.
+ Failf("traverse failed for root type %v and target type %v", rootType, targetType)
+ }
+ panic("unreachable")
+}
+
+// encodeMap encodes a map.
+func (es *encodeState) encodeMap(obj reflect.Value, dest *wire.Object) {
+ if obj.IsNil() {
+ // Because there is a difference between a nil map and an empty
+ // map, we need to not decode in the case of a truly nil map.
+ *dest = wire.Nil{}
+ return
+ }
+ l := obj.Len()
+ m := &wire.Map{
+ Keys: make([]wire.Object, l),
+ Values: make([]wire.Object, l),
+ }
+ *dest = m
+ for i, k := range obj.MapKeys() {
+ v := obj.MapIndex(k)
+ // Map keys must be encoded using the full value because the
+ // type will be omitted after the first key.
+ es.encodeObject(k, encodeAsValue, &m.Keys[i])
+ es.encodeObject(v, encodeAsValue, &m.Values[i])
+ }
+}
+
+// objectEncoder is for encoding structs.
+type objectEncoder struct {
+ // es is encodeState.
+ es *encodeState
+
+ // encoded is the encoded struct.
+ encoded *wire.Struct
+}
+
+// save is called by the public methods on Sink.
+func (oe *objectEncoder) save(slot int, obj reflect.Value) {
+ fieldValue := oe.encoded.Field(slot)
+ oe.es.encodeObject(obj, encodeDefault, fieldValue)
+}
+
+// encodeStruct encodes a composite object.
+func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
+ // Ensure that the obj is addressable. There are two cases when it is
+ // not. First, is when this is dispatched via SaveValue. Second, when
+ // this is a map key as a struct. Either way, we need to make a copy to
+ // obtain an addressable value.
+ if !obj.CanAddr() {
+ localObj := reflect.New(obj.Type())
+ localObj.Elem().Set(obj)
+ obj = localObj.Elem()
+ }
+
+ // Prepare the value.
+ s := &wire.Struct{}
+ *dest = s
+
+ // Look the type up in the database.
+ te, ok := es.types.Lookup(obj.Type())
+ if te == nil {
+ if obj.NumField() == 0 {
+ // Allow unregistered anonymous, empty structs. This
+ // will just return success without ever invoking the
+ // passed function. This uses the immutable EmptyStruct
+ // variable to prevent an allocation in this case.
+ //
+ // Note that this mechanism does *not* work for
+ // interfaces in general. So you can't dispatch
+ // non-registered empty structs via interfaces because
+ // then they can't be restored.
+ s.Alloc(0)
+ return
+ }
+ // We need a SaverLoader for struct types.
+ Failf("struct %T does not implement SaverLoader", obj.Interface())
+ }
+ if !ok {
+ // Queue the type to be serialized.
+ es.pendingTypes = append(es.pendingTypes, te.Type)
+ }
+
+ // Invoke the provided saver.
+ s.TypeID = wire.TypeID(te.ID)
+ s.Alloc(len(te.Fields))
+ oe := objectEncoder{
+ es: es,
+ encoded: s,
+ }
+ es.stats.start(te.ID)
+ defer es.stats.done()
+ if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
+ // Note: may be a registered empty struct which does not
+ // implement the saver/loader interfaces.
+ sl.StateSave(Sink{internal: oe})
+ }
+}
+
+// encodeArray encodes an array.
+func (es *encodeState) encodeArray(obj reflect.Value, dest *wire.Object) {
+ l := obj.Len()
+ a := &wire.Array{
+ Contents: make([]wire.Object, l),
+ }
+ *dest = a
+ for i := 0; i < l; i++ {
+ // We need to encode the full value because arrays are encoded
+ // using the type information from only the first element.
+ es.encodeObject(obj.Index(i), encodeAsValue, &a.Contents[i])
+ }
+}
+
+// findType recursively finds type information.
+func (es *encodeState) findType(typ reflect.Type) wire.TypeSpec {
+ // First: check if this is a proper type. It's possible for pointers,
+ // slices, arrays, maps, etc to all have some different type.
+ te, ok := es.types.Lookup(typ)
+ if te != nil {
+ if !ok {
+ // See encodeStruct.
+ es.pendingTypes = append(es.pendingTypes, te.Type)
+ }
+ return wire.TypeID(te.ID)
+ }
+
+ switch typ.Kind() {
+ case reflect.Ptr:
+ return &wire.TypeSpecPointer{
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Slice:
+ return &wire.TypeSpecSlice{
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Array:
+ return &wire.TypeSpecArray{
+ Count: wire.Uint(typ.Len()),
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Map:
+ return &wire.TypeSpecMap{
+ Key: es.findType(typ.Key()),
+ Value: es.findType(typ.Elem()),
+ }
+ default:
+ // After potentially chasing many pointers, the
+ // ultimate type of the object is not known.
+ Failf("type %q is not known", typ)
+ }
+ panic("unreachable")
+}
+
+// encodeInterface encodes an interface.
+func (es *encodeState) encodeInterface(obj reflect.Value, dest *wire.Object) {
+ // Dereference the object.
+ obj = obj.Elem()
+ if !obj.IsValid() {
+ // Special case: the nil object.
+ *dest = &wire.Interface{
+ Type: wire.TypeSpecNil{},
+ Value: wire.Nil{},
+ }
+ return
+ }
+
+ // Encode underlying object.
+ i := &wire.Interface{
+ Type: es.findType(obj.Type()),
+ }
+ *dest = i
+ es.encodeObject(obj, encodeAsValue, &i.Value)
+}
+
+// isPrimitive returns true if this is a primitive object, or a composite
+// object composed entirely of primitives.
+func isPrimitiveZero(typ reflect.Type) bool {
+ switch typ.Kind() {
+ case reflect.Ptr:
+ // Pointers are always treated as primitive types because we
+ // won't encode directly from here. Returning true here won't
+ // prevent the object from being encoded correctly.
+ return true
+ case reflect.Bool:
+ return true
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return true
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return true
+ case reflect.Float32, reflect.Float64:
+ return true
+ case reflect.Complex64, reflect.Complex128:
+ return true
+ case reflect.String:
+ return true
+ case reflect.Slice:
+ // The slice itself a primitive, but not necessarily the array
+ // that points to. This is similar to a pointer.
+ return true
+ case reflect.Array:
+ // We cannot treat an array as a primitive, because it may be
+ // composed of structures or other things with side-effects.
+ return isPrimitiveZero(typ.Elem())
+ case reflect.Interface:
+ // Since we now that this type is the zero type, the interface
+ // value must be zero. Therefore this is primitive.
+ return true
+ case reflect.Struct:
+ return false
+ case reflect.Map:
+ // The isPrimitiveZero function is called only on zero-types to
+ // see if it's safe to serialize. Since a zero map has no
+ // elements, it is safe to treat as a primitive.
+ return true
+ default:
+ Failf("unknown type %q", typ.Name())
+ }
+ panic("unreachable")
+}
+
+// encodeStrategy is the strategy used for encodeObject.
+type encodeStrategy int
+
+const (
+ // encodeDefault means types are encoded normally as references.
+ encodeDefault encodeStrategy = iota
+
+ // encodeAsValue means that types will never take short-circuited and
+ // will always be encoded as a normal value.
+ encodeAsValue
+
+ // encodeMapAsValue means that even maps will be fully encoded.
+ encodeMapAsValue
+)
+
+// encodeObject encodes an object.
+func (es *encodeState) encodeObject(obj reflect.Value, how encodeStrategy, dest *wire.Object) {
+ if how == encodeDefault && isPrimitiveZero(obj.Type()) && obj.IsZero() {
+ *dest = wire.Nil{}
+ return
+ }
+ switch obj.Kind() {
+ case reflect.Ptr: // Fast path: first.
+ r := new(wire.Ref)
+ *dest = r
+ if obj.IsNil() {
+ // May be in an array or elsewhere such that a value is
+ // required. So we encode as a reference to the zero
+ // object, which does not exist. Note that this has to
+ // be handled correctly in the decode path as well.
+ return
+ }
+ es.resolve(obj, r)
+ case reflect.Bool:
+ *dest = wire.Bool(obj.Bool())
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ *dest = wire.Int(obj.Int())
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ *dest = wire.Uint(obj.Uint())
+ case reflect.Float32:
+ *dest = wire.Float32(obj.Float())
+ case reflect.Float64:
+ *dest = wire.Float64(obj.Float())
+ case reflect.Complex64:
+ c := wire.Complex64(obj.Complex())
+ *dest = &c // Needs alloc.
+ case reflect.Complex128:
+ c := wire.Complex128(obj.Complex())
+ *dest = &c // Needs alloc.
+ case reflect.String:
+ s := wire.String(obj.String())
+ *dest = &s // Needs alloc.
+ case reflect.Array:
+ es.encodeArray(obj, dest)
+ case reflect.Slice:
+ s := &wire.Slice{
+ Capacity: wire.Uint(obj.Cap()),
+ Length: wire.Uint(obj.Len()),
+ }
+ *dest = s
+ // Note that we do need to provide a wire.Slice type here as
+ // how is not encodeDefault. If this were the case, then it
+ // would have been caught by the IsZero check above and we
+ // would have just used wire.Nil{}.
+ if obj.IsNil() {
+ return
+ }
+ // Slices need pointer resolution.
+ es.resolve(arrayFromSlice(obj), &s.Ref)
+ case reflect.Interface:
+ es.encodeInterface(obj, dest)
+ case reflect.Struct:
+ es.encodeStruct(obj, dest)
+ case reflect.Map:
+ if how == encodeMapAsValue {
+ es.encodeMap(obj, dest)
+ return
+ }
+ r := new(wire.Ref)
+ *dest = r
+ es.resolve(obj, r)
+ default:
+ Failf("unknown object %#v", obj.Interface())
+ panic("unreachable")
+ }
+}
+
+// Save serializes the object graph rooted at obj.
+func (es *encodeState) Save(obj reflect.Value) {
+ es.stats.init()
+ defer es.stats.fini(func(id typeID) string {
+ return es.pendingTypes[id-1].Name
+ })
+
+ // Resolve the first object, which should queue a pile of additional
+ // objects on the pending list. All queued objects should be fully
+ // resolved, and we should be able to serialize after this call.
+ var root wire.Ref
+ es.resolve(obj.Addr(), &root)
+
+ // Encode the graph.
+ var oes *objectEncodeState
+ if err := safely(func() {
+ for oes = es.deferred.Front(); oes != nil; oes = es.deferred.Front() {
+ // Remove and encode the object. Note that as a result
+ // of this encoding, the object may be enqueued on the
+ // deferred list yet again. That's expected, and why it
+ // is removed first.
+ es.deferred.Remove(oes)
+ es.encodeObject(oes.obj, oes.how, &oes.encoded)
+ }
+ }); err != nil {
+ // Include the object in the error message.
+ Failf("encoding error at object %#v: %w", oes.obj.Interface(), err)
+ }
+
+ // Check that items are pending.
+ if es.pending.Front() == nil {
+ Failf("pending is empty?")
+ }
+
+ // Write the header with the number of objects. Note that there is no
+ // way that es.lastID could conflict with objectID, which would
+ // indicate that an impossibly large encoding.
+ if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil {
+ Failf("error writing header: %w", err)
+ }
+
+ // Serialize all pending types and pending objects. Note that we don't
+ // bother removing from this list as we walk it because that just
+ // wastes time. It will not change after this point.
+ var id objectID
+ if err := safely(func() {
+ for _, wt := range es.pendingTypes {
+ // Encode the type.
+ wire.Save(es.w, &wt)
+ }
+ for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() {
+ id++ // First object is 1.
+ if oes.id != id {
+ Failf("expected id %d, got %d", id, oes.id)
+ }
+
+ // Marshall the object.
+ wire.Save(es.w, oes.encoded)
+ }
+ }); err != nil {
+ // Include the object and the error.
+ Failf("error serializing object %#v: %w", oes.encoded, err)
+ }
+
+ // Check what we wrote.
+ if id != es.lastID {
+ Failf("expected %d objects, wrote %d", es.lastID, id)
+ }
+}
+
+// objectFlag indicates that the length is a # of objects, rather than a raw
+// byte length. When this is set on a length header in the stream, it may be
+// decoded appropriately.
+const objectFlag uint64 = 1 << 63
+
+// WriteHeader writes a header.
+//
+// Each object written to the statefile should be prefixed with a header. In
+// order to generate statefiles that play nicely with debugging tools, raw
+// writes should be prefixed with a header with object set to false and the
+// appropriate length. This will allow tools to skip these regions.
+func WriteHeader(w wire.Writer, length uint64, object bool) error {
+ // Sanity check the length.
+ if length&objectFlag != 0 {
+ Failf("impossibly huge length: %d", length)
+ }
+ if object {
+ length |= objectFlag
+ }
+
+ // Write a header.
+ return safely(func() {
+ wire.SaveUint(w, length)
+ })
+}
+
+// pendingMapper is for the pending list.
+type pendingMapper struct{}
+
+func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry }
+
+// deferredMapper is for the deferred list.
+type deferredMapper struct{}
+
+func (deferredMapper) linkerFor(oes *objectEncodeState) *deferredEntry { return &oes.deferredEntry }
+
+// addrSetFunctions is used by addrSet.
+type addrSetFunctions struct{}
+
+func (addrSetFunctions) MinKey() uintptr {
+ return 0
+}
+
+func (addrSetFunctions) MaxKey() uintptr {
+ return ^uintptr(0)
+}
+
+func (addrSetFunctions) ClearValue(val **objectEncodeState) {
+ *val = nil
+}
+
+func (addrSetFunctions) Merge(r1 addrRange, val1 *objectEncodeState, r2 addrRange, val2 *objectEncodeState) (*objectEncodeState, bool) {
+ if val1.obj == val2.obj {
+ // This, should never happen. It would indicate that the same
+ // object exists in two non-contiguous address ranges. Note
+ // that this assertion can only be triggered if the race
+ // detector is enabled.
+ Failf("unexpected merge in addrSet @ %v and %v: %#v and %#v", r1, r2, val1.obj, val2.obj)
+ }
+ // Reject the merge.
+ return val1, false
+}
+
+func (addrSetFunctions) Split(r addrRange, val *objectEncodeState, _ uintptr) (*objectEncodeState, *objectEncodeState) {
+ // A split should never happen: we don't remove ranges.
+ Failf("unexpected split in addrSet @ %v: %#v", r, val.obj)
+ panic("unreachable")
+}
diff --git a/pkg/state/encode_unsafe.go b/pkg/state/encode_unsafe.go
new file mode 100644
index 000000000..e0dad83b4
--- /dev/null
+++ b/pkg/state/encode_unsafe.go
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// arrayFromSlice constructs a new pointer to the slice data.
+//
+// It would be similar to the following:
+//
+// x := make([]Foo, l, c)
+// a := ([l]Foo*)(unsafe.Pointer(x[0]))
+//
+func arrayFromSlice(obj reflect.Value) reflect.Value {
+ return reflect.NewAt(
+ reflect.ArrayOf(obj.Cap(), obj.Type().Elem()),
+ unsafe.Pointer(obj.Pointer()))
+}
diff --git a/pkg/state/pretty/BUILD b/pkg/state/pretty/BUILD
new file mode 100644
index 000000000..d053802f7
--- /dev/null
+++ b/pkg/state/pretty/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pretty",
+ srcs = ["pretty.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/state",
+ "//pkg/state/wire",
+ ],
+)
diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go
new file mode 100644
index 000000000..cf37aaa49
--- /dev/null
+++ b/pkg/state/pretty/pretty.go
@@ -0,0 +1,273 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pretty is a pretty-printer for state streams.
+package pretty
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "reflect"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+func formatRef(x *wire.Ref, graph uint64, html bool) string {
+ baseRef := fmt.Sprintf("g%dr%d", graph, x.Root)
+ fullRef := baseRef
+ if len(x.Dots) > 0 {
+ // See wire.Ref; Type valid if Dots non-zero.
+ typ, _ := formatType(x.Type, graph, html)
+ var buf strings.Builder
+ buf.WriteString("(*")
+ buf.WriteString(typ)
+ buf.WriteString(")(")
+ buf.WriteString(baseRef)
+ for _, component := range x.Dots {
+ switch v := component.(type) {
+ case *wire.FieldName:
+ buf.WriteString(".")
+ buf.WriteString(string(*v))
+ case wire.Index:
+ buf.WriteString(fmt.Sprintf("[%d]", v))
+ default:
+ panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component)))
+ }
+ }
+ buf.WriteString(")")
+ fullRef = buf.String()
+ }
+ if html {
+ return fmt.Sprintf("<a href=\"#%s\">%s</a>", baseRef, fullRef)
+ }
+ return fullRef
+}
+
+func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) {
+ switch x := t.(type) {
+ case wire.TypeID:
+ base := fmt.Sprintf("g%dt%d", graph, x)
+ if html {
+ return fmt.Sprintf("<a href=\"#%s\">%s</a>", base, base), true
+ }
+ return fmt.Sprintf("%s", base), true
+ case wire.TypeSpecNil:
+ return "", false // Only nil type.
+ case *wire.TypeSpecPointer:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("(*%s)", element), true
+ case *wire.TypeSpecArray:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("[%d](%s)", x.Count, element), true
+ case *wire.TypeSpecSlice:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("([]%s)", element), true
+ case *wire.TypeSpecMap:
+ key, _ := formatType(x.Key, graph, html)
+ value, _ := formatType(x.Value, graph, html)
+ return fmt.Sprintf("(map[%s]%s)", key, value), true
+ default:
+ panic(fmt.Sprintf("unreachable: unknown type %T", t))
+ }
+}
+
+// format formats a single object, for pretty-printing. It also returns whether
+// the value is a non-zero value.
+func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bool) {
+ switch x := encoded.(type) {
+ case wire.Nil:
+ return "nil", false
+ case *wire.String:
+ return fmt.Sprintf("%q", *x), *x != ""
+ case *wire.Complex64:
+ return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0
+ case *wire.Complex128:
+ return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0
+ case *wire.Ref:
+ return formatRef(x, graph, html), x.Root != 0
+ case *wire.Type:
+ tabs := "\n" + strings.Repeat("\t", depth)
+ items := make([]string, 0, len(x.Fields)+2)
+ items = append(items, fmt.Sprintf("type %s {", x.Name))
+ for i := 0; i < len(x.Fields); i++ {
+ items = append(items, fmt.Sprintf("\t%d: %s,", i, x.Fields[i]))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), true // No zero value.
+ case *wire.Slice:
+ return fmt.Sprintf("%s{len:%d,cap:%d}", formatRef(&x.Ref, graph, html), x.Length, x.Capacity), x.Capacity != 0
+ case *wire.Array:
+ if len(x.Contents) == 0 {
+ return "[]", false
+ }
+ items := make([]string, 0, len(x.Contents)+2)
+ zeros := make([]string, 0) // used to eliminate zero entries.
+ items = append(items, "[")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ for i := 0; i < len(x.Contents); i++ {
+ item, ok := format(graph, depth+1, x.Contents[i], html)
+ if !ok {
+ zeros = append(zeros, fmt.Sprintf("\t%s,", item))
+ continue
+ }
+ if len(zeros) > 0 {
+ items = append(items, zeros...)
+ zeros = nil
+ }
+ items = append(items, fmt.Sprintf("\t%s,", item))
+ }
+ if len(zeros) > 0 {
+ items = append(items, fmt.Sprintf("\t... (%d zeros),", len(zeros)))
+ }
+ items = append(items, "]")
+ return strings.Join(items, tabs), len(zeros) < len(x.Contents)
+ case *wire.Struct:
+ typ, _ := formatType(x.TypeID, graph, html)
+ if x.Fields() == 0 {
+ return fmt.Sprintf("struct[%s]{}", typ), false
+ }
+ items := make([]string, 0, 2)
+ items = append(items, fmt.Sprintf("struct[%s]{", typ))
+ tabs := "\n" + strings.Repeat("\t", depth)
+ allZero := true
+ for i := 0; i < x.Fields(); i++ {
+ element, ok := format(graph, depth+1, *x.Field(i), html)
+ allZero = allZero && !ok
+ items = append(items, fmt.Sprintf("\t%d: %s,", i, element))
+ i++
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), !allZero
+ case *wire.Map:
+ if len(x.Keys) == 0 {
+ return "map{}", false
+ }
+ items := make([]string, 0, len(x.Keys)+2)
+ items = append(items, "map{")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ for i := 0; i < len(x.Keys); i++ {
+ key, _ := format(graph, depth+1, x.Keys[i], html)
+ value, _ := format(graph, depth+1, x.Values[i], html)
+ items = append(items, fmt.Sprintf("\t%s: %s,", key, value))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), true
+ case *wire.Interface:
+ typ, typOk := formatType(x.Type, graph, html)
+ element, elementOk := format(graph, depth+1, x.Value, html)
+ return fmt.Sprintf("interface[%s]{%s}", typ, element), typOk || elementOk
+ default:
+ // Must be a primitive; use reflection.
+ return fmt.Sprintf("%v", encoded), true
+ }
+}
+
+// printStream is the basic print implementation.
+func printStream(w io.Writer, r wire.Reader, html bool) (err error) {
+ // current graph ID.
+ var graph uint64
+
+ if html {
+ fmt.Fprintf(w, "<pre>")
+ defer fmt.Fprintf(w, "</pre>")
+ }
+
+ defer func() {
+ if r := recover(); r != nil {
+ if rErr, ok := r.(error); ok {
+ err = rErr // Override return.
+ return
+ }
+ panic(r) // Propagate.
+ }
+ }()
+
+ for {
+ // Find the first object to begin generation.
+ length, object, err := state.ReadHeader(r)
+ if err == io.EOF {
+ // Nothing else to do.
+ break
+ } else if err != nil {
+ return err
+ }
+ if !object {
+ graph++ // Increment the graph.
+ if length > 0 {
+ fmt.Fprintf(w, "(%d bytes non-object data)\n", length)
+ io.Copy(ioutil.Discard, &io.LimitedReader{
+ R: r,
+ N: int64(length),
+ })
+ }
+ continue
+ }
+
+ // Read & unmarshal the object.
+ //
+ // Note that this loop must match the general structure of the
+ // loop in decode.go. But we don't register type information,
+ // etc. and just print the raw structures.
+ var (
+ oid uint64 = 1
+ tid uint64 = 1
+ )
+ for oid <= length {
+ // Unmarshal the object.
+ encoded := wire.Load(r)
+
+ // Is this a type?
+ if _, ok := encoded.(*wire.Type); ok {
+ str, _ := format(graph, 0, encoded, html)
+ tag := fmt.Sprintf("g%dt%d", graph, tid)
+ if html {
+ // See below.
+ tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
+ }
+ if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+ return err
+ }
+ tid++
+ continue
+ }
+
+ // Format the node.
+ str, _ := format(graph, 0, encoded, html)
+ tag := fmt.Sprintf("g%dr%d", graph, oid)
+ if html {
+ // Create a little tag with an anchor next to it for linking.
+ tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
+ }
+ if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+ return err
+ }
+ oid++
+ }
+ }
+
+ return nil
+}
+
+// PrintText reads the stream from r and prints text to w.
+func PrintText(w io.Writer, r wire.Reader) error {
+ return printStream(w, r, false /* html */)
+}
+
+// PrintHTML reads the stream from r and prints html to w.
+func PrintHTML(w io.Writer, r wire.Reader) error {
+ return printStream(w, r, true /* html */)
+}
diff --git a/pkg/state/state.go b/pkg/state/state.go
new file mode 100644
index 000000000..acb629969
--- /dev/null
+++ b/pkg/state/state.go
@@ -0,0 +1,321 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package state provides functionality related to saving and loading object
+// graphs. For most types, it provides a set of default saving / loading logic
+// that will be invoked automatically if custom logic is not defined.
+//
+// Kind Support
+// ---- -------
+// Bool default
+// Int default
+// Int8 default
+// Int16 default
+// Int32 default
+// Int64 default
+// Uint default
+// Uint8 default
+// Uint16 default
+// Uint32 default
+// Uint64 default
+// Float32 default
+// Float64 default
+// Complex64 default
+// Complex128 default
+// Array default
+// Chan custom
+// Func custom
+// Interface default
+// Map default
+// Ptr default
+// Slice default
+// String default
+// Struct custom (*) Unless zero-sized.
+// UnsafePointer custom
+//
+// See README.md for an overview of how encoding and decoding works.
+package state
+
+import (
+ "context"
+ "fmt"
+ "reflect"
+ "runtime"
+
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// objectID is a unique identifier assigned to each object to be serialized.
+// Each instance of an object is considered separately, i.e. if there are two
+// objects of the same type in the object graph being serialized, they'll be
+// assigned unique objectIDs.
+type objectID uint32
+
+// typeID is the identifier for a type. Types are serialized and tracked
+// alongside objects in order to avoid the overhead of encoding field names in
+// all objects.
+type typeID uint32
+
+// ErrState is returned when an error is encountered during encode/decode.
+type ErrState struct {
+ // err is the underlying error.
+ err error
+
+ // trace is the stack trace.
+ trace string
+}
+
+// Error returns a sensible description of the state error.
+func (e *ErrState) Error() string {
+ return fmt.Sprintf("%v:\n%s", e.err, e.trace)
+}
+
+// Unwrap implements standard unwrapping.
+func (e *ErrState) Unwrap() error {
+ return e.err
+}
+
+// Save saves the given object state.
+func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) {
+ // Create the encoding state.
+ es := encodeState{
+ ctx: ctx,
+ w: w,
+ types: makeTypeEncodeDatabase(),
+ zeroValues: make(map[reflect.Type]*objectEncodeState),
+ }
+
+ // Perform the encoding.
+ err := safely(func() {
+ es.Save(reflect.ValueOf(rootPtr).Elem())
+ })
+ return es.stats, err
+}
+
+// Load loads a checkpoint.
+func Load(ctx context.Context, r wire.Reader, rootPtr interface{}) (Stats, error) {
+ // Create the decoding state.
+ ds := decodeState{
+ ctx: ctx,
+ r: r,
+ types: makeTypeDecodeDatabase(),
+ deferred: make(map[objectID]wire.Object),
+ }
+
+ // Attempt our decode.
+ err := safely(func() {
+ ds.Load(reflect.ValueOf(rootPtr).Elem())
+ })
+ return ds.stats, err
+}
+
+// Sink is used for Type.StateSave.
+type Sink struct {
+ internal objectEncoder
+}
+
+// Save adds the given object to the map.
+//
+// You should pass always pointers to the object you are saving. For example:
+//
+// type X struct {
+// A int
+// B *int
+// }
+//
+// func (x *X) StateTypeInfo(m Sink) state.TypeInfo {
+// return state.TypeInfo{
+// Name: "pkg.X",
+// Fields: []string{
+// "A",
+// "B",
+// },
+// }
+// }
+//
+// func (x *X) StateSave(m Sink) {
+// m.Save(0, &x.A) // Field is A.
+// m.Save(1, &x.B) // Field is B.
+// }
+//
+// func (x *X) StateLoad(m Source) {
+// m.Load(0, &x.A) // Field is A.
+// m.Load(1, &x.B) // Field is B.
+// }
+func (s Sink) Save(slot int, objPtr interface{}) {
+ s.internal.save(slot, reflect.ValueOf(objPtr).Elem())
+}
+
+// SaveValue adds the given object value to the map.
+//
+// This should be used for values where pointers are not available, or casts
+// are required during Save/Load.
+//
+// For example, if we want to cast external package type P.Foo to int64:
+//
+// func (x *X) StateSave(m Sink) {
+// m.SaveValue(0, "A", int64(x.A))
+// }
+//
+// func (x *X) StateLoad(m Source) {
+// m.LoadValue(0, new(int64), func(x interface{}) {
+// x.A = P.Foo(x.(int64))
+// })
+// }
+func (s Sink) SaveValue(slot int, obj interface{}) {
+ s.internal.save(slot, reflect.ValueOf(obj))
+}
+
+// Context returns the context object provided at save time.
+func (s Sink) Context() context.Context {
+ return s.internal.es.ctx
+}
+
+// Type is an interface that must be implemented by Struct objects. This allows
+// these objects to be serialized while minimizing runtime reflection required.
+//
+// All these methods can be automatically generated by the go_statify tool.
+type Type interface {
+ // StateTypeName returns the type's name.
+ //
+ // This is used for matching type information during encoding and
+ // decoding, as well as dynamic interface dispatch. This should be
+ // globally unique.
+ StateTypeName() string
+
+ // StateFields returns information about the type.
+ //
+ // Fields is the set of fields for the object. Calls to Sink.Save and
+ // Source.Load must be made in-order with respect to these fields.
+ //
+ // This will be called at most once per serialization.
+ StateFields() []string
+}
+
+// SaverLoader must be implemented by struct types.
+type SaverLoader interface {
+ // StateSave saves the state of the object to the given Map.
+ StateSave(Sink)
+
+ // StateLoad loads the state of the object.
+ StateLoad(Source)
+}
+
+// Source is used for Type.StateLoad.
+type Source struct {
+ internal objectDecoder
+}
+
+// Load loads the given object passed as a pointer..
+//
+// See Sink.Save for an example.
+func (s Source) Load(slot int, objPtr interface{}) {
+ s.internal.load(slot, reflect.ValueOf(objPtr), false, nil)
+}
+
+// LoadWait loads the given objects from the map, and marks it as requiring all
+// AfterLoad executions to complete prior to running this object's AfterLoad.
+//
+// See Sink.Save for an example.
+func (s Source) LoadWait(slot int, objPtr interface{}) {
+ s.internal.load(slot, reflect.ValueOf(objPtr), true, nil)
+}
+
+// LoadValue loads the given object value from the map.
+//
+// See Sink.SaveValue for an example.
+func (s Source) LoadValue(slot int, objPtr interface{}, fn func(interface{})) {
+ o := reflect.ValueOf(objPtr)
+ s.internal.load(slot, o, true, func() { fn(o.Elem().Interface()) })
+}
+
+// AfterLoad schedules a function execution when all objects have been
+// allocated and their automated loading and customized load logic have been
+// executed. fn will not be executed until all of current object's
+// dependencies' AfterLoad() logic, if exist, have been executed.
+func (s Source) AfterLoad(fn func()) {
+ s.internal.afterLoad(fn)
+}
+
+// Context returns the context object provided at load time.
+func (s Source) Context() context.Context {
+ return s.internal.ds.ctx
+}
+
+// IsZeroValue checks if the given value is the zero value.
+//
+// This function is used by the stateify tool.
+func IsZeroValue(val interface{}) bool {
+ return val == nil || reflect.ValueOf(val).Elem().IsZero()
+}
+
+// Failf is a wrapper around panic that should be used to generate errors that
+// can be caught during saving and loading.
+func Failf(fmtStr string, v ...interface{}) {
+ panic(fmt.Errorf(fmtStr, v...))
+}
+
+// safely executes the given function, catching a panic and unpacking as an
+// error.
+//
+// The error flow through the state package uses panic and recover. There are
+// two important reasons for this:
+//
+// 1) Many of the reflection methods will already panic with invalid data or
+// violated assumptions. We would want to recover anyways here.
+//
+// 2) It allows us to eliminate boilerplate within Save() and Load() functions.
+// In nearly all cases, when the low-level serialization functions fail, you
+// will want the checkpoint to fail anyways. Plumbing errors through every
+// method doesn't add a lot of value. If there are specific error conditions
+// that you'd like to handle, you should add appropriate functionality to
+// objects themselves prior to calling Save() and Load().
+func safely(fn func()) (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if es, ok := r.(*ErrState); ok {
+ err = es // Propagate.
+ return
+ }
+
+ // Build a new state error.
+ es := new(ErrState)
+ if e, ok := r.(error); ok {
+ es.err = e
+ } else {
+ es.err = fmt.Errorf("%v", r)
+ }
+
+ // Make a stack. We don't know how big it will be ahead
+ // of time, but want to make sure we get the whole
+ // thing. So we just do a stupid brute force approach.
+ var stack []byte
+ for sz := 1024; ; sz *= 2 {
+ stack = make([]byte, sz)
+ n := runtime.Stack(stack, false)
+ if n < sz {
+ es.trace = string(stack[:n])
+ break
+ }
+ }
+
+ // Set the error.
+ err = es
+ }
+ }()
+
+ // Execute the function.
+ fn()
+ return nil
+}
diff --git a/pkg/state/state_norace.go b/pkg/state/state_norace.go
new file mode 100644
index 000000000..4281aed6d
--- /dev/null
+++ b/pkg/state/state_norace.go
@@ -0,0 +1,19 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+
+package state
+
+var raceEnabled = false
diff --git a/pkg/state/state_race.go b/pkg/state/state_race.go
new file mode 100644
index 000000000..8232981ce
--- /dev/null
+++ b/pkg/state/state_race.go
@@ -0,0 +1,19 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build race
+
+package state
+
+var raceEnabled = true
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
new file mode 100644
index 000000000..d6c89c7e9
--- /dev/null
+++ b/pkg/state/statefile/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "statefile",
+ srcs = ["statefile.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/binary",
+ "//pkg/compressio",
+ "//pkg/state/wire",
+ ],
+)
+
+go_test(
+ name = "statefile_test",
+ size = "small",
+ srcs = ["statefile_test.go"],
+ library = ":statefile",
+ deps = ["//pkg/compressio"],
+)
diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go
new file mode 100644
index 000000000..bdfb800fb
--- /dev/null
+++ b/pkg/state/statefile/statefile.go
@@ -0,0 +1,239 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package statefile defines the state file data stream.
+//
+// This package currently does not include any details regarding the state
+// encoding itself, only details regarding state metadata and data layout.
+//
+// The file format is defined as follows.
+//
+// /------------------------------------------------------\
+// | header (8-bytes) |
+// +------------------------------------------------------+
+// | metadata length (8-bytes) |
+// +------------------------------------------------------+
+// | metadata |
+// +------------------------------------------------------+
+// | data |
+// \------------------------------------------------------/
+//
+// First, it includes a 8-byte magic header which is the following
+// sequence of bytes [0x67, 0x56, 0x69, 0x73, 0x6f, 0x72, 0x53, 0x46]
+//
+// This header is followed by an 8-byte length N (big endian), and an
+// ASCII-encoded JSON map that is exactly N bytes long.
+//
+// This map includes only strings for keys and strings for values. Keys in the
+// map that begin with "_" are for internal use only. They may be read, but may
+// not be provided by the user. In the future, this metadata may contain some
+// information relating to the state encoding itself.
+//
+// After the map, the remainder of the file is the state data.
+package statefile
+
+import (
+ "bytes"
+ "compress/flate"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/json"
+ "fmt"
+ "hash"
+ "io"
+ "strings"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/compressio"
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// keySize is the AES-256 key length.
+const keySize = 32
+
+// compressionChunkSize is the chunk size for compression.
+const compressionChunkSize = 1024 * 1024
+
+// maxMetadataSize is the size limit of metadata section.
+const maxMetadataSize = 16 * 1024 * 1024
+
+// magicHeader is the byte sequence beginning each file.
+var magicHeader = []byte("\x67\x56\x69\x73\x6f\x72\x53\x46")
+
+// ErrBadMagic is returned if the header does not match.
+var ErrBadMagic = fmt.Errorf("bad magic header")
+
+// ErrMetadataMissing is returned if the state file is missing mandatory metadata.
+var ErrMetadataMissing = fmt.Errorf("missing metadata")
+
+// ErrInvalidMetadataLength is returned if the metadata length is too large.
+var ErrInvalidMetadataLength = fmt.Errorf("metadata length invalid, maximum size is %d", maxMetadataSize)
+
+// ErrMetadataInvalid is returned if passed metadata is invalid.
+var ErrMetadataInvalid = fmt.Errorf("metadata invalid, can't start with _")
+
+// WriteCloser is an io.Closer and wire.Writer.
+type WriteCloser interface {
+ wire.Writer
+ io.Closer
+}
+
+// NewWriter returns a state data writer for a statefile.
+//
+// Note that the returned WriteCloser must be closed.
+func NewWriter(w io.Writer, key []byte, metadata map[string]string) (WriteCloser, error) {
+ if metadata == nil {
+ metadata = make(map[string]string)
+ }
+ for k := range metadata {
+ if strings.HasPrefix(k, "_") {
+ return nil, ErrMetadataInvalid
+ }
+ }
+
+ // Create our HMAC function.
+ h := hmac.New(sha256.New, key)
+ mw := io.MultiWriter(w, h)
+
+ // First, write the header.
+ if _, err := mw.Write(magicHeader); err != nil {
+ return nil, err
+ }
+
+ // Generate a timestamp, for convenience only.
+ metadata["_timestamp"] = time.Now().UTC().String()
+ defer delete(metadata, "_timestamp")
+
+ // Write the metadata.
+ b, err := json.Marshal(metadata)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(b) > maxMetadataSize {
+ return nil, ErrInvalidMetadataLength
+ }
+
+ // Metadata length.
+ if err := binary.WriteUint64(mw, binary.BigEndian, uint64(len(b))); err != nil {
+ return nil, err
+ }
+ // Metadata bytes; io.MultiWriter will return a short write error if
+ // any of the writers returns < n.
+ if _, err := mw.Write(b); err != nil {
+ return nil, err
+ }
+ // Write the current hash.
+ cur := h.Sum(nil)
+ for done := 0; done < len(cur); {
+ n, err := mw.Write(cur[done:])
+ done += n
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Wrap in compression. We always use "best speed" mode here. When using
+ // "best compression" mode, there is usually only a little gain in file
+ // size reduction, which translate to even smaller gain in restore
+ // latency reduction, while inccuring much more CPU usage at save time.
+ return compressio.NewWriter(w, key, compressionChunkSize, flate.BestSpeed)
+}
+
+// MetadataUnsafe reads out the metadata from a state file without verifying any
+// HMAC. This function shouldn't be called for untrusted input files.
+func MetadataUnsafe(r io.Reader) (map[string]string, error) {
+ return metadata(r, nil)
+}
+
+// metadata validates the magic header and reads out the metadata from a state
+// data stream.
+func metadata(r io.Reader, h hash.Hash) (map[string]string, error) {
+ if h != nil {
+ r = io.TeeReader(r, h)
+ }
+
+ // Read and validate magic header.
+ b := make([]byte, len(magicHeader))
+ if _, err := r.Read(b); err != nil {
+ return nil, err
+ }
+ if !bytes.Equal(b, magicHeader) {
+ return nil, ErrBadMagic
+ }
+
+ // Read and validate metadata.
+ b, err := func() (b []byte, err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ b = nil
+ err = fmt.Errorf("%v", r)
+ }
+ }()
+
+ metadataLen, err := binary.ReadUint64(r, binary.BigEndian)
+ if err != nil {
+ return nil, err
+ }
+ if metadataLen > maxMetadataSize {
+ return nil, ErrInvalidMetadataLength
+ }
+ b = make([]byte, int(metadataLen))
+ if _, err := io.ReadFull(r, b); err != nil {
+ return nil, err
+ }
+ return b, nil
+ }()
+ if err != nil {
+ return nil, err
+ }
+
+ if h != nil {
+ // Check the hash prior to decoding.
+ cur := h.Sum(nil)
+ buf := make([]byte, len(cur))
+ if _, err := io.ReadFull(r, buf); err != nil {
+ return nil, err
+ }
+ if !hmac.Equal(cur, buf) {
+ return nil, compressio.ErrHashMismatch
+ }
+ }
+
+ // Decode the metadata.
+ metadata := make(map[string]string)
+ if err := json.Unmarshal(b, &metadata); err != nil {
+ return nil, err
+ }
+
+ return metadata, nil
+}
+
+// NewReader returns a reader for a statefile.
+func NewReader(r io.Reader, key []byte) (wire.Reader, map[string]string, error) {
+ // Read the metadata with the hash.
+ h := hmac.New(sha256.New, key)
+ metadata, err := metadata(r, h)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Wrap in compression.
+ cr, err := compressio.NewReader(r, key)
+ if err != nil {
+ return nil, nil, err
+ }
+ return cr, metadata, nil
+}
diff --git a/pkg/state/statefile/statefile_test.go b/pkg/state/statefile/statefile_test.go
new file mode 100644
index 000000000..0b470fdec
--- /dev/null
+++ b/pkg/state/statefile/statefile_test.go
@@ -0,0 +1,290 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package statefile
+
+import (
+ "bytes"
+ crand "crypto/rand"
+ "encoding/base64"
+ "io"
+ "math/rand"
+ "runtime"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/compressio"
+)
+
+func randomKey() ([]byte, error) {
+ r := make([]byte, base64.RawStdEncoding.DecodedLen(keySize))
+ if _, err := io.ReadFull(crand.Reader, r); err != nil {
+ return nil, err
+ }
+ key := make([]byte, keySize)
+ base64.RawStdEncoding.Encode(key, r)
+ return key, nil
+}
+
+type testCase struct {
+ name string
+ data []byte
+ metadata map[string]string
+}
+
+func TestStatefile(t *testing.T) {
+ rand.Seed(time.Now().Unix())
+
+ cases := []testCase{
+ // Various data sizes.
+ {"nil", nil, nil},
+ {"empty", []byte(""), nil},
+ {"some", []byte("_"), nil},
+ {"one", []byte("0"), nil},
+ {"two", []byte("01"), nil},
+ {"three", []byte("012"), nil},
+ {"four", []byte("0123"), nil},
+ {"five", []byte("01234"), nil},
+ {"six", []byte("012356"), nil},
+ {"seven", []byte("0123567"), nil},
+ {"eight", []byte("01235678"), nil},
+
+ // Make sure we have one longer than the hash length.
+ {"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil},
+
+ // Make sure we have one longer than the chunk size.
+ {"chunks", make([]byte, 3*compressionChunkSize), nil},
+ {"large", make([]byte, 30*compressionChunkSize), nil},
+
+ // Different metadata.
+ {"one metadata", []byte("data"), map[string]string{"foo": "bar"}},
+ {"two metadata", []byte("data"), map[string]string{"foo": "bar", "one": "two"}},
+ }
+
+ for _, c := range cases {
+ // Generate a key.
+ integrityKey, err := randomKey()
+ if err != nil {
+ t.Errorf("can't generate key: got %v, excepted nil", err)
+ continue
+ }
+
+ t.Run(c.name, func(t *testing.T) {
+ for _, key := range [][]byte{nil, integrityKey} {
+ t.Run("key="+string(key), func(t *testing.T) {
+ // Encoding happens via a buffer.
+ var bufEncoded bytes.Buffer
+ var bufDecoded bytes.Buffer
+
+ // Do all the writing.
+ w, err := NewWriter(&bufEncoded, key, c.metadata)
+ if err != nil {
+ t.Fatalf("error creating writer: got %v, expected nil", err)
+ }
+ if _, err := io.Copy(w, bytes.NewBuffer(c.data)); err != nil {
+ t.Fatalf("error during write: got %v, expected nil", err)
+ }
+
+ // Finish the sum.
+ if err := w.Close(); err != nil {
+ t.Fatalf("error during close: got %v, expected nil", err)
+ }
+
+ t.Logf("original data: %d bytes, encoded: %d bytes.",
+ len(c.data), len(bufEncoded.Bytes()))
+
+ // Do all the reading.
+ r, metadata, err := NewReader(bytes.NewReader(bufEncoded.Bytes()), key)
+ if err != nil {
+ t.Fatalf("error creating reader: got %v, expected nil", err)
+ }
+ if _, err := io.Copy(&bufDecoded, r); err != nil {
+ t.Fatalf("error during read: got %v, expected nil", err)
+ }
+
+ // Check that the data matches.
+ if !bytes.Equal(c.data, bufDecoded.Bytes()) {
+ t.Fatalf("data didn't match (%d vs %d bytes)", len(bufDecoded.Bytes()), len(c.data))
+ }
+
+ // Check that the metadata matches.
+ for k, v := range c.metadata {
+ nv, ok := metadata[k]
+ if !ok {
+ t.Fatalf("missing metadata: %s", k)
+ }
+ if v != nv {
+ t.Fatalf("mismatched metdata for %s: got %s, expected %s", k, nv, v)
+ }
+ }
+
+ // Change the data and verify that it fails.
+ if key != nil {
+ b := append([]byte(nil), bufEncoded.Bytes()...)
+ b[rand.Intn(len(b))]++
+ bufDecoded.Reset()
+ r, _, err = NewReader(bytes.NewReader(b), key)
+ if err == nil {
+ _, err = io.Copy(&bufDecoded, r)
+ }
+ if err == nil {
+ t.Error("got no error: expected error on data corruption")
+ }
+ }
+
+ // Change the key and verify that it fails.
+ newKey := integrityKey
+ if len(key) > 0 {
+ newKey = append([]byte{}, key...)
+ newKey[rand.Intn(len(newKey))]++
+ }
+ bufDecoded.Reset()
+ r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), newKey)
+ if err == nil {
+ _, err = io.Copy(&bufDecoded, r)
+ }
+ if err != compressio.ErrHashMismatch {
+ t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err)
+ }
+ })
+ }
+ })
+ }
+}
+
+const benchmarkDataSize = 100 * 1024 * 1024
+
+func benchmark(b *testing.B, size int, write bool, compressible bool) {
+ b.StopTimer()
+ b.SetBytes(benchmarkDataSize)
+
+ // Generate source data.
+ var source []byte
+ if compressible {
+ // For compressible data, we use essentially all zeros.
+ source = make([]byte, benchmarkDataSize)
+ } else {
+ // For non-compressible data, we use random base64 data (to
+ // make it marginally compressible, a ratio of 75%).
+ var sourceBuf bytes.Buffer
+ bufW := base64.NewEncoder(base64.RawStdEncoding, &sourceBuf)
+ bufR := rand.New(rand.NewSource(0))
+ if _, err := io.CopyN(bufW, bufR, benchmarkDataSize); err != nil {
+ b.Fatalf("unable to seed random data: %v", err)
+ }
+ source = sourceBuf.Bytes()
+ }
+
+ // Generate a random key for integrity check.
+ key, err := randomKey()
+ if err != nil {
+ b.Fatalf("error generating key: %v", err)
+ }
+
+ // Define our benchmark functions. Prior to running the readState
+ // function here, you must execute the writeState function at least
+ // once (done below).
+ var stateBuf bytes.Buffer
+ writeState := func() {
+ stateBuf.Reset()
+ w, err := NewWriter(&stateBuf, key, nil)
+ if err != nil {
+ b.Fatalf("error creating writer: %v", err)
+ }
+ for done := 0; done < len(source); {
+ chunk := size // limit size.
+ if done+chunk > len(source) {
+ chunk = len(source) - done
+ }
+ n, err := w.Write(source[done : done+chunk])
+ done += n
+ if n == 0 && err != nil {
+ b.Fatalf("error during write: %v", err)
+ }
+ }
+ if err := w.Close(); err != nil {
+ b.Fatalf("error closing writer: %v", err)
+ }
+ }
+ readState := func() {
+ tmpBuf := bytes.NewBuffer(stateBuf.Bytes())
+ r, _, err := NewReader(tmpBuf, key)
+ if err != nil {
+ b.Fatalf("error creating reader: %v", err)
+ }
+ for done := 0; done < len(source); {
+ chunk := size // limit size.
+ if done+chunk > len(source) {
+ chunk = len(source) - done
+ }
+ n, err := r.Read(source[done : done+chunk])
+ done += n
+ if n == 0 && err != nil {
+ b.Fatalf("error during read: %v", err)
+ }
+ }
+ }
+ // Generate the state once without timing to ensure that buffers have
+ // been appropriately allocated.
+ writeState()
+ if write {
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ writeState()
+ }
+ b.StopTimer()
+ } else {
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ readState()
+ }
+ b.StopTimer()
+ }
+}
+
+func BenchmarkWrite4KCompressible(b *testing.B) {
+ benchmark(b, 4096, true, true)
+}
+
+func BenchmarkWrite4KNoncompressible(b *testing.B) {
+ benchmark(b, 4096, true, false)
+}
+
+func BenchmarkWrite1MCompressible(b *testing.B) {
+ benchmark(b, 1024*1024, true, true)
+}
+
+func BenchmarkWrite1MNoncompressible(b *testing.B) {
+ benchmark(b, 1024*1024, true, false)
+}
+
+func BenchmarkRead4KCompressible(b *testing.B) {
+ benchmark(b, 4096, false, true)
+}
+
+func BenchmarkRead4KNoncompressible(b *testing.B) {
+ benchmark(b, 4096, false, false)
+}
+
+func BenchmarkRead1MCompressible(b *testing.B) {
+ benchmark(b, 1024*1024, false, true)
+}
+
+func BenchmarkRead1MNoncompressible(b *testing.B) {
+ benchmark(b, 1024*1024, false, false)
+}
+
+func init() {
+ runtime.GOMAXPROCS(runtime.NumCPU())
+}
diff --git a/pkg/state/stats.go b/pkg/state/stats.go
new file mode 100644
index 000000000..eaec664a1
--- /dev/null
+++ b/pkg/state/stats.go
@@ -0,0 +1,145 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "time"
+)
+
+type statEntry struct {
+ count uint
+ total time.Duration
+}
+
+// Stats tracks encode / decode timing.
+//
+// This currently provides a meaningful String function and no other way to
+// extract stats about individual types.
+//
+// All exported receivers accept nil.
+type Stats struct {
+ // byType contains a breakdown of time spent by type.
+ //
+ // This is indexed *directly* by typeID, including zero.
+ byType []statEntry
+
+ // stack contains objects in progress.
+ stack []typeID
+
+ // names contains type names.
+ //
+ // This is also indexed *directly* by typeID, including zero, which we
+ // hard-code as "state.default". This is only resolved by calling fini
+ // on the stats object.
+ names []string
+
+ // last is the last start time.
+ last time.Time
+}
+
+// init initializes statistics.
+func (s *Stats) init() {
+ s.last = time.Now()
+ s.stack = append(s.stack, 0)
+}
+
+// fini finalizes statistics.
+func (s *Stats) fini(resolve func(id typeID) string) {
+ s.done()
+
+ // Resolve all type names.
+ s.names = make([]string, len(s.byType))
+ s.names[0] = "state.default" // See above.
+ for id := typeID(1); int(id) < len(s.names); id++ {
+ s.names[id] = resolve(id)
+ }
+}
+
+// sample adds the samples to the given object.
+func (s *Stats) sample(id typeID) {
+ now := time.Now()
+ if len(s.byType) <= int(id) {
+ // Allocate all the missing entries in one fell swoop.
+ s.byType = append(s.byType, make([]statEntry, 1+int(id)-len(s.byType))...)
+ }
+ s.byType[id].total += now.Sub(s.last)
+ s.last = now
+}
+
+// start starts a sample.
+func (s *Stats) start(id typeID) {
+ last := s.stack[len(s.stack)-1]
+ s.sample(last)
+ s.stack = append(s.stack, id)
+}
+
+// done finishes the current sample.
+func (s *Stats) done() {
+ last := s.stack[len(s.stack)-1]
+ s.sample(last)
+ s.byType[last].count++
+ s.stack = s.stack[:len(s.stack)-1]
+}
+
+type sliceEntry struct {
+ name string
+ entry *statEntry
+}
+
+// String returns a table representation of the stats.
+func (s *Stats) String() string {
+ // Build a list of stat entries.
+ ss := make([]sliceEntry, 0, len(s.byType))
+ for id := 0; id < len(s.names); id++ {
+ ss = append(ss, sliceEntry{
+ name: s.names[id],
+ entry: &s.byType[id],
+ })
+ }
+
+ // Sort by total time (descending).
+ sort.Slice(ss, func(i, j int) bool {
+ return ss[i].entry.total > ss[j].entry.total
+ })
+
+ // Print the stat results.
+ var (
+ buf bytes.Buffer
+ count uint
+ total time.Duration
+ )
+ buf.WriteString("\n")
+ buf.WriteString(fmt.Sprintf("% 16s | % 8s | % 16s | %s\n", "total", "count", "per", "type"))
+ buf.WriteString("-----------------+----------+------------------+----------------\n")
+ for _, se := range ss {
+ if se.entry.count == 0 {
+ // Since we store all types linearly, we are not
+ // guaranteed that any entry actually has time.
+ continue
+ }
+ count += se.entry.count
+ total += se.entry.total
+ per := se.entry.total / time.Duration(se.entry.count)
+ buf.WriteString(fmt.Sprintf("% 16s | %8d | % 16s | %s\n",
+ se.entry.total, se.entry.count, per, se.name))
+ }
+ buf.WriteString("-----------------+----------+------------------+----------------\n")
+ buf.WriteString(fmt.Sprintf("% 16s | % 8d | % 16s | [all]",
+ total, count, total/time.Duration(count)))
+ return string(buf.Bytes())
+}
diff --git a/pkg/state/tests/BUILD b/pkg/state/tests/BUILD
new file mode 100644
index 000000000..9297cafbe
--- /dev/null
+++ b/pkg/state/tests/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tests",
+ srcs = [
+ "array.go",
+ "bench.go",
+ "integer.go",
+ "load.go",
+ "map.go",
+ "register.go",
+ "struct.go",
+ "tests.go",
+ ],
+ deps = [
+ "//pkg/state",
+ "//pkg/state/pretty",
+ ],
+)
+
+go_test(
+ name = "tests_test",
+ size = "small",
+ srcs = [
+ "array_test.go",
+ "bench_test.go",
+ "bool_test.go",
+ "float_test.go",
+ "integer_test.go",
+ "load_test.go",
+ "map_test.go",
+ "register_test.go",
+ "string_test.go",
+ "struct_test.go",
+ ],
+ library = ":tests",
+ deps = [
+ "//pkg/state",
+ "//pkg/state/wire",
+ ],
+)
diff --git a/pkg/state/tests/array.go b/pkg/state/tests/array.go
new file mode 100644
index 000000000..0972a80e7
--- /dev/null
+++ b/pkg/state/tests/array.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type arrayContainer struct {
+ v [1]interface{}
+}
+
+// +stateify savable
+type arrayPtrContainer struct {
+ v *[1]interface{}
+}
+
+// +stateify savable
+type sliceContainer struct {
+ v []interface{}
+}
+
+// +stateify savable
+type slicePtrContainer struct {
+ v *[]interface{}
+}
diff --git a/pkg/state/tests/array_test.go b/pkg/state/tests/array_test.go
new file mode 100644
index 000000000..a347b2947
--- /dev/null
+++ b/pkg/state/tests/array_test.go
@@ -0,0 +1,134 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "reflect"
+ "testing"
+)
+
+var allArrayPrimitives = []interface{}{
+ [1]bool{},
+ [1]bool{true},
+ [2]bool{false, true},
+ [1]int{},
+ [1]int{1},
+ [2]int{0, 1},
+ [1]int8{},
+ [1]int8{1},
+ [2]int8{0, 1},
+ [1]int16{},
+ [1]int16{1},
+ [2]int16{0, 1},
+ [1]int32{},
+ [1]int32{1},
+ [2]int32{0, 1},
+ [1]int64{},
+ [1]int64{1},
+ [2]int64{0, 1},
+ [1]uint{},
+ [1]uint{1},
+ [2]uint{0, 1},
+ [1]uintptr{},
+ [1]uintptr{1},
+ [2]uintptr{0, 1},
+ [1]uint8{},
+ [1]uint8{1},
+ [2]uint8{0, 1},
+ [1]uint16{},
+ [1]uint16{1},
+ [2]uint16{0, 1},
+ [1]uint32{},
+ [1]uint32{1},
+ [2]uint32{0, 1},
+ [1]uint64{},
+ [1]uint64{1},
+ [2]uint64{0, 1},
+ [1]string{},
+ [1]string{""},
+ [1]string{nonEmptyString},
+ [2]string{"", nonEmptyString},
+}
+
+func TestArrayPrimitives(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allArrayPrimitives))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allArrayPrimitives)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allArrayPrimitives)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allArrayPrimitives))))
+}
+
+func TestSlices(t *testing.T) {
+ var allSlices = flatten(
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ return v.Slice(0, v.Len()).Interface(), true
+ }),
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ if v.Len() == 0 {
+ // Return the pure "nil" value for the slice.
+ return reflect.New(v.Slice(0, 0).Type()).Elem().Interface(), true
+ }
+ return v.Slice(1, v.Len()).Interface(), true
+ }),
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ if v.Len() == 0 {
+ // Return the zero-valued slice.
+ return reflect.MakeSlice(v.Slice(0, 0).Type(), 0, 0).Interface(), true
+ }
+ return v.Slice(0, v.Len()-1).Interface(), true
+ }),
+ )
+ runTestCases(t, false, "plain", allSlices)
+ runTestCases(t, false, "pointers", pointersTo(allSlices))
+ runTestCases(t, false, "interfaces", interfacesTo(allSlices))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(allSlices)))
+}
+
+func TestArrayContainers(t *testing.T) {
+ var (
+ emptyArray [1]interface{}
+ fullArray [1]interface{}
+ )
+ fullArray[0] = &emptyArray
+ runTestCases(t, false, "", []interface{}{
+ arrayContainer{v: emptyArray},
+ arrayContainer{v: fullArray},
+ arrayPtrContainer{v: nil},
+ arrayPtrContainer{v: &emptyArray},
+ arrayPtrContainer{v: &fullArray},
+ })
+}
+
+func TestSliceContainers(t *testing.T) {
+ var (
+ nilSlice []interface{}
+ emptySlice = make([]interface{}, 0)
+ fullSlice = []interface{}{nil}
+ )
+ runTestCases(t, false, "", []interface{}{
+ sliceContainer{v: nilSlice},
+ sliceContainer{v: emptySlice},
+ sliceContainer{v: fullSlice},
+ slicePtrContainer{v: nil},
+ slicePtrContainer{v: &nilSlice},
+ slicePtrContainer{v: &emptySlice},
+ slicePtrContainer{v: &fullSlice},
+ })
+}
diff --git a/pkg/state/tests/bench.go b/pkg/state/tests/bench.go
new file mode 100644
index 000000000..40869cdfb
--- /dev/null
+++ b/pkg/state/tests/bench.go
@@ -0,0 +1,24 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type benchStruct struct {
+ B *benchStruct // Must be exported for gob.
+}
+
+func (b *benchStruct) afterLoad() {
+ // Do nothing, just force scheduling.
+}
diff --git a/pkg/state/tests/bench_test.go b/pkg/state/tests/bench_test.go
new file mode 100644
index 000000000..7e102c907
--- /dev/null
+++ b/pkg/state/tests/bench_test.go
@@ -0,0 +1,153 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "bytes"
+ "context"
+ "encoding/gob"
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// buildPtrObject builds a benchmark object.
+func buildPtrObject(n int) interface{} {
+ b := new(benchStruct)
+ for i := 0; i < n; i++ {
+ b = &benchStruct{B: b}
+ }
+ return b
+}
+
+// buildMapObject builds a benchmark object.
+func buildMapObject(n int) interface{} {
+ b := new(benchStruct)
+ m := make(map[int]*benchStruct)
+ for i := 0; i < n; i++ {
+ m[i] = b
+ }
+ return &m
+}
+
+// buildSliceObject builds a benchmark object.
+func buildSliceObject(n int) interface{} {
+ b := new(benchStruct)
+ s := make([]*benchStruct, 0, n)
+ for i := 0; i < n; i++ {
+ s = append(s, b)
+ }
+ return &s
+}
+
+var allObjects = map[string]struct {
+ New func(int) interface{}
+}{
+ "ptr": {
+ New: buildPtrObject,
+ },
+ "map": {
+ New: buildMapObject,
+ },
+ "slice": {
+ New: buildSliceObject,
+ },
+}
+
+func buildObjects(n int, fn func(int) interface{}) (iters int, v interface{}) {
+ // maxSize is the maximum size of an individual object below. For an N
+ // larger than this, we start to return multiple objects.
+ const maxSize = 1024
+ if n <= maxSize {
+ return 1, fn(n)
+ }
+ iters = (n + maxSize - 1) / maxSize
+ return iters, fn(maxSize)
+}
+
+// gobSave is a version of save using gob (no stats available).
+func gobSave(_ context.Context, w wire.Writer, v interface{}) (_ state.Stats, err error) {
+ enc := gob.NewEncoder(w)
+ err = enc.Encode(v)
+ return
+}
+
+// gobLoad is a version of load using gob (no stats available).
+func gobLoad(_ context.Context, r wire.Reader, v interface{}) (_ state.Stats, err error) {
+ dec := gob.NewDecoder(r)
+ err = dec.Decode(v)
+ return
+}
+
+var allAlgos = map[string]struct {
+ Save func(context.Context, wire.Writer, interface{}) (state.Stats, error)
+ Load func(context.Context, wire.Reader, interface{}) (state.Stats, error)
+ MaxPtr int
+}{
+ "state": {
+ Save: state.Save,
+ Load: state.Load,
+ },
+ "gob": {
+ Save: gobSave,
+ Load: gobLoad,
+ },
+}
+
+func BenchmarkEncoding(b *testing.B) {
+ for objName, objInfo := range allObjects {
+ for algoName, algoInfo := range allAlgos {
+ b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
+ b.StopTimer()
+ n, v := buildObjects(b.N, objInfo.New)
+ b.ReportAllocs()
+ b.StartTimer()
+ for i := 0; i < n; i++ {
+ if _, err := algoInfo.Save(context.Background(), discard{}, v); err != nil {
+ b.Errorf("save failed: %v", err)
+ }
+ }
+ b.StopTimer()
+ })
+ }
+ }
+}
+
+func BenchmarkDecoding(b *testing.B) {
+ for objName, objInfo := range allObjects {
+ for algoName, algoInfo := range allAlgos {
+ b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
+ b.StopTimer()
+ n, v := buildObjects(b.N, objInfo.New)
+ buf := new(bytes.Buffer)
+ if _, err := algoInfo.Save(context.Background(), buf, v); err != nil {
+ b.Errorf("save failed: %v", err)
+ }
+ b.ReportAllocs()
+ b.StartTimer()
+ var r bytes.Reader
+ for i := 0; i < n; i++ {
+ r.Reset(buf.Bytes())
+ if _, err := algoInfo.Load(context.Background(), &r, v); err != nil {
+ b.Errorf("load failed: %v", err)
+ }
+ }
+ b.StopTimer()
+ })
+ }
+ }
+}
diff --git a/pkg/state/tests/bool_test.go b/pkg/state/tests/bool_test.go
new file mode 100644
index 000000000..e17cfacf9
--- /dev/null
+++ b/pkg/state/tests/bool_test.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+var allBools = []bool{
+ true,
+ false,
+}
+
+func TestBool(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allBools))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allBools)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allBools)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allBools))))
+}
diff --git a/pkg/state/tests/float_test.go b/pkg/state/tests/float_test.go
new file mode 100644
index 000000000..3e89edd9c
--- /dev/null
+++ b/pkg/state/tests/float_test.go
@@ -0,0 +1,118 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "math"
+ "testing"
+)
+
+var safeFloat32s = []float32{
+ float32(0.0),
+ float32(1.0),
+ float32(-1.0),
+ float32(math.Inf(1)),
+ float32(math.Inf(-1)),
+}
+
+var allFloat32s = append(safeFloat32s, float32(math.NaN()))
+
+var safeFloat64s = []float64{
+ float64(0.0),
+ float64(1.0),
+ float64(-1.0),
+ math.Inf(1),
+ math.Inf(-1),
+}
+
+var allFloat64s = append(safeFloat64s, math.NaN())
+
+func TestFloat(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(
+ allFloat32s,
+ allFloat64s,
+ ))
+ // See checkEqual for why NaNs are missing.
+ runTestCases(t, false, "pointers", pointersTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ )))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ )))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ ))))
+}
+
+const onlyDouble float64 = 1.0000000000000002
+
+func TestFloatTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingFloat32{save: onlyDouble},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingFloat32{save: 1.0},
+ })
+}
+
+var safeComplex64s = combine(safeFloat32s, safeFloat32s, func(i, j interface{}) interface{} {
+ return complex(i.(float32), j.(float32))
+})
+
+var allComplex64s = combine(allFloat32s, allFloat32s, func(i, j interface{}) interface{} {
+ return complex(i.(float32), j.(float32))
+})
+
+var safeComplex128s = combine(safeFloat64s, safeFloat64s, func(i, j interface{}) interface{} {
+ return complex(i.(float64), j.(float64))
+})
+
+var allComplex128s = combine(allFloat64s, allFloat64s, func(i, j interface{}) interface{} {
+ return complex(i.(float64), j.(float64))
+})
+
+func TestComplex(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(
+ allComplex64s,
+ allComplex128s,
+ ))
+ // See TestFloat; same issue.
+ runTestCases(t, false, "pointers", pointersTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ )))
+ runTestCases(t, false, "interfacse", interfacesTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ )))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ ))))
+}
+
+func TestComplexTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingComplex64{save: complex(onlyDouble, onlyDouble)},
+ truncatingComplex64{save: complex(1.0, onlyDouble)},
+ truncatingComplex64{save: complex(onlyDouble, 1.0)},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingComplex64{save: complex(1.0, 1.0)},
+ })
+}
diff --git a/pkg/state/tests/integer.go b/pkg/state/tests/integer.go
new file mode 100644
index 000000000..ca403eed1
--- /dev/null
+++ b/pkg/state/tests/integer.go
@@ -0,0 +1,163 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+// +stateify type
+type truncatingUint8 struct {
+ save uint64
+ load uint8 `state:"nosave"`
+}
+
+func (t *truncatingUint8) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint8) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint8)(nil)
+
+// +stateify type
+type truncatingUint16 struct {
+ save uint64
+ load uint16 `state:"nosave"`
+}
+
+func (t *truncatingUint16) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint16) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint16)(nil)
+
+// +stateify type
+type truncatingUint32 struct {
+ save uint64
+ load uint32 `state:"nosave"`
+}
+
+func (t *truncatingUint32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint32)(nil)
+
+// +stateify type
+type truncatingInt8 struct {
+ save int64
+ load int8 `state:"nosave"`
+}
+
+func (t *truncatingInt8) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt8) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt8)(nil)
+
+// +stateify type
+type truncatingInt16 struct {
+ save int64
+ load int16 `state:"nosave"`
+}
+
+func (t *truncatingInt16) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt16) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt16)(nil)
+
+// +stateify type
+type truncatingInt32 struct {
+ save int64
+ load int32 `state:"nosave"`
+}
+
+func (t *truncatingInt32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt32)(nil)
+
+// +stateify type
+type truncatingFloat32 struct {
+ save float64
+ load float32 `state:"nosave"`
+}
+
+func (t *truncatingFloat32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingFloat32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = float64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingFloat32)(nil)
+
+// +stateify type
+type truncatingComplex64 struct {
+ save complex128
+ load complex64 `state:"nosave"`
+}
+
+func (t *truncatingComplex64) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingComplex64) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = complex128(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingComplex64)(nil)
diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go
new file mode 100644
index 000000000..d3931c952
--- /dev/null
+++ b/pkg/state/tests/integer_test.go
@@ -0,0 +1,94 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "math"
+ "testing"
+)
+
+var (
+ allIntTs = []int{-1, 0, 1}
+ allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
+ allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
+ allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
+ allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
+ allUintTs = []uint{0, 1}
+ allUintptrs = []uintptr{0, 1, ^uintptr(0)}
+ allUint8s = []uint8{0, 1, math.MaxUint8}
+ allUint16s = []uint16{0, 1, math.MaxUint16}
+ allUint32s = []uint32{0, 1, math.MaxUint32}
+ allUint64s = []uint64{0, 1, math.MaxUint64}
+)
+
+var allInts = flatten(
+ allIntTs,
+ allInt8s,
+ allInt16s,
+ allInt32s,
+ allInt64s,
+)
+
+var allUints = flatten(
+ allUintTs,
+ allUintptrs,
+ allUint8s,
+ allUint16s,
+ allUint32s,
+ allUint64s,
+)
+
+func TestInt(t *testing.T) {
+ runTestCases(t, false, "plain", allInts)
+ runTestCases(t, false, "pointers", pointersTo(allInts))
+ runTestCases(t, false, "interfaces", interfacesTo(allInts))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allInts)))
+}
+
+func TestIntTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingInt8{save: math.MinInt8 - 1},
+ truncatingInt16{save: math.MinInt16 - 1},
+ truncatingInt32{save: math.MinInt32 - 1},
+ truncatingInt8{save: math.MaxInt8 + 1},
+ truncatingInt16{save: math.MaxInt16 + 1},
+ truncatingInt32{save: math.MaxInt32 + 1},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingInt8{save: 1},
+ truncatingInt16{save: 1},
+ truncatingInt32{save: 1},
+ })
+}
+
+func TestUint(t *testing.T) {
+ runTestCases(t, false, "plain", allUints)
+ runTestCases(t, false, "pointers", pointersTo(allUints))
+ runTestCases(t, false, "interfaces", interfacesTo(allUints))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allUints)))
+}
+
+func TestUintTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingUint8{save: math.MaxUint8 + 1},
+ truncatingUint16{save: math.MaxUint16 + 1},
+ truncatingUint32{save: math.MaxUint32 + 1},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingUint8{save: 1},
+ truncatingUint16{save: 1},
+ truncatingUint32{save: 1},
+ })
+}
diff --git a/pkg/state/tests/load.go b/pkg/state/tests/load.go
new file mode 100644
index 000000000..a8350c0f3
--- /dev/null
+++ b/pkg/state/tests/load.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type genericContainer struct {
+ v interface{}
+}
+
+// +stateify savable
+type afterLoadStruct struct {
+ v int `state:"nosave"`
+}
+
+func (a *afterLoadStruct) afterLoad() {
+ a.v++
+}
+
+// +stateify savable
+type valueLoadStruct struct {
+ v int `state:".(int64)"`
+}
+
+func (v *valueLoadStruct) saveV() int64 {
+ return int64(v.v) // Save as int64.
+}
+
+func (v *valueLoadStruct) loadV(value int64) {
+ v.v = int(value) // Load as int.
+}
+
+// +stateify savable
+type cycleStruct struct {
+ c *cycleStruct
+}
+
+// +stateify savable
+type badCycleStruct struct {
+ b *badCycleStruct `state:"wait"`
+}
+
+func (b *badCycleStruct) afterLoad() {
+ if b.b != b {
+ // This is not executable, since AfterLoad requires that the
+ // object and all dependencies are complete. This should cause
+ // a deadlock error during load.
+ panic("badCycleStruct.afterLoad called")
+ }
+}
diff --git a/pkg/state/tests/load_test.go b/pkg/state/tests/load_test.go
new file mode 100644
index 000000000..1e9794296
--- /dev/null
+++ b/pkg/state/tests/load_test.go
@@ -0,0 +1,70 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+func TestLoadHooks(t *testing.T) {
+ runTestCases(t, false, "load-hooks", []interface{}{
+ &afterLoadStruct{v: 1},
+ &valueLoadStruct{v: 1},
+ &genericContainer{v: &afterLoadStruct{v: 1}},
+ &genericContainer{v: &valueLoadStruct{v: 1}},
+ &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}},
+ &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}},
+ &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}},
+ &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}},
+ })
+}
+
+func TestCycles(t *testing.T) {
+ // cs is a single object cycle.
+ cs := cycleStruct{nil}
+ cs.c = &cs
+
+ // cs1 and cs2 are in a two object cycle.
+ cs1 := cycleStruct{nil}
+ cs2 := cycleStruct{nil}
+ cs1.c = &cs2
+ cs2.c = &cs1
+
+ runTestCases(t, false, "cycles", []interface{}{
+ cs,
+ cs1,
+ })
+}
+
+func TestDeadlock(t *testing.T) {
+ // bs is a single object cycle. This does not cause deadlock because an
+ // object cannot wait for itself.
+ bs := badCycleStruct{nil}
+ bs.b = &bs
+
+ runTestCases(t, false, "self", []interface{}{
+ &bs,
+ })
+
+ // bs2 and bs2 are in a deadlocking cycle.
+ bs1 := badCycleStruct{nil}
+ bs2 := badCycleStruct{nil}
+ bs1.b = &bs2
+ bs2.b = &bs1
+
+ runTestCases(t, true, "deadlock", []interface{}{
+ &bs1,
+ })
+}
diff --git a/pkg/state/tests/map.go b/pkg/state/tests/map.go
new file mode 100644
index 000000000..db4e548f1
--- /dev/null
+++ b/pkg/state/tests/map.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type mapContainer struct {
+ v map[int]interface{}
+}
+
+// +stateify savable
+type mapPtrContainer struct {
+ v *map[int]interface{}
+}
+
+// +stateify savable
+type registeredMapStruct struct{}
diff --git a/pkg/state/tests/map_test.go b/pkg/state/tests/map_test.go
new file mode 100644
index 000000000..92bf0fc01
--- /dev/null
+++ b/pkg/state/tests/map_test.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "reflect"
+ "testing"
+)
+
+var allMapPrimitives = []interface{}{
+ bool(true),
+ int(1),
+ int8(1),
+ int16(1),
+ int32(1),
+ int64(1),
+ uint(1),
+ uintptr(1),
+ uint8(1),
+ uint16(1),
+ uint32(1),
+ uint64(1),
+ string(""),
+ registeredMapStruct{},
+}
+
+var allMapKeys = flatten(allMapPrimitives, pointersTo(allMapPrimitives))
+
+var allMapValues = flatten(allMapPrimitives, pointersTo(allMapPrimitives), interfacesTo(allMapPrimitives))
+
+var emptyMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} {
+ m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2)))
+ return m.Interface()
+})
+
+var fullMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} {
+ m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2)))
+ m.SetMapIndex(reflect.Zero(reflect.TypeOf(v1)), reflect.Zero(reflect.TypeOf(v2)))
+ return m.Interface()
+})
+
+func TestMapAliasing(t *testing.T) {
+ v := make(map[int]int)
+ ptrToV := &v
+ aliases := []map[int]int{v, v}
+ runTestCases(t, false, "", []interface{}{ptrToV, aliases})
+}
+
+func TestMapsEmpty(t *testing.T) {
+ runTestCases(t, false, "plain", emptyMaps)
+ runTestCases(t, false, "pointers", pointersTo(emptyMaps))
+ runTestCases(t, false, "interfaces", interfacesTo(emptyMaps))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(emptyMaps)))
+}
+
+func TestMapsFull(t *testing.T) {
+ runTestCases(t, false, "plain", fullMaps)
+ runTestCases(t, false, "pointers", pointersTo(fullMaps))
+ runTestCases(t, false, "interfaces", interfacesTo(fullMaps))
+ runTestCases(t, false, "interfacesToPointer", interfacesTo(pointersTo(fullMaps)))
+}
+
+func TestMapContainers(t *testing.T) {
+ var (
+ nilMap map[int]interface{}
+ emptyMap = make(map[int]interface{})
+ fullMap = map[int]interface{}{0: nil}
+ )
+ runTestCases(t, false, "", []interface{}{
+ mapContainer{v: nilMap},
+ mapContainer{v: emptyMap},
+ mapContainer{v: fullMap},
+ mapPtrContainer{v: nil},
+ mapPtrContainer{v: &nilMap},
+ mapPtrContainer{v: &emptyMap},
+ mapPtrContainer{v: &fullMap},
+ })
+}
diff --git a/pkg/state/tests/register.go b/pkg/state/tests/register.go
new file mode 100644
index 000000000..074d86315
--- /dev/null
+++ b/pkg/state/tests/register.go
@@ -0,0 +1,21 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type alreadyRegisteredStruct struct{}
+
+// +stateify savable
+type alreadyRegisteredOther int
diff --git a/pkg/state/tests/register_test.go b/pkg/state/tests/register_test.go
new file mode 100644
index 000000000..c829753cc
--- /dev/null
+++ b/pkg/state/tests/register_test.go
@@ -0,0 +1,167 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+// faker calls itself whatever is in the name field.
+type faker struct {
+ Name string
+ Fields []string
+}
+
+func (f *faker) StateTypeName() string {
+ return f.Name
+}
+
+func (f *faker) StateFields() []string {
+ return f.Fields
+}
+
+// fakerWithSaverLoader has all it needs.
+type fakerWithSaverLoader struct {
+ faker
+}
+
+func (f *fakerWithSaverLoader) StateSave(m state.Sink) {}
+
+func (f *fakerWithSaverLoader) StateLoad(m state.Source) {}
+
+// fakerOther calls itself .. uh, itself?
+type fakerOther string
+
+func (f *fakerOther) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOther) StateFields() []string {
+ return nil
+}
+
+func newFakerOther(name string) *fakerOther {
+ f := fakerOther(name)
+ return &f
+}
+
+// fakerOtherBadFields returns non-nil fields.
+type fakerOtherBadFields string
+
+func (f *fakerOtherBadFields) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOtherBadFields) StateFields() []string {
+ return []string{string(*f)}
+}
+
+func newFakerOtherBadFields(name string) *fakerOtherBadFields {
+ f := fakerOtherBadFields(name)
+ return &f
+}
+
+// fakerOtherSaverLoader implements SaverLoader methods.
+type fakerOtherSaverLoader string
+
+func (f *fakerOtherSaverLoader) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOtherSaverLoader) StateFields() []string {
+ return nil
+}
+
+func (f *fakerOtherSaverLoader) StateSave(m state.Sink) {}
+
+func (f *fakerOtherSaverLoader) StateLoad(m state.Source) {}
+
+func newFakerOtherSaverLoader(name string) *fakerOtherSaverLoader {
+ f := fakerOtherSaverLoader(name)
+ return &f
+}
+
+func TestRegisterPrimitives(t *testing.T) {
+ for _, typeName := range []string{
+ "int",
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "uint",
+ "uintptr",
+ "uint8",
+ "uint16",
+ "uint32",
+ "uint64",
+ "float32",
+ "float64",
+ "complex64",
+ "complex128",
+ "string",
+ } {
+ t.Run("struct/"+typeName, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering type %q did not panic", typeName)
+ }
+ }()
+ state.Register(&faker{
+ Name: typeName,
+ })
+ })
+ t.Run("other/"+typeName, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering type %q did not panic", typeName)
+ }
+ }()
+ state.Register(newFakerOther(typeName))
+ })
+ }
+}
+
+func TestRegisterBad(t *testing.T) {
+ const (
+ goodName = "foo"
+ firstField = "a"
+ secondField = "b"
+ )
+ for name, object := range map[string]state.Type{
+ "non-struct-with-fields": newFakerOtherBadFields(goodName),
+ "non-struct-with-saverloader": newFakerOtherSaverLoader(goodName),
+ "struct-without-saverloader": &faker{Name: goodName},
+ "non-struct-duplicate-with-struct": newFakerOther((new(alreadyRegisteredStruct)).StateTypeName()),
+ "non-struct-duplicate-with-non-struct": newFakerOther((new(alreadyRegisteredOther)).StateTypeName()),
+ "struct-duplicate-with-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredStruct)).StateTypeName()}},
+ "struct-duplicate-with-non-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredOther)).StateTypeName()}},
+ "struct-with-empty-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{""}}},
+ "struct-with-empty-field-and-non-empty": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, ""}}},
+ "struct-with-duplicate-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, firstField}}},
+ "struct-with-duplicate-field-and-non-dup": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, secondField, firstField}}},
+ } {
+ t.Run(name, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering object %#v did not panic", object)
+ }
+ }()
+ state.Register(object)
+ })
+
+ }
+}
diff --git a/pkg/state/tests/string_test.go b/pkg/state/tests/string_test.go
new file mode 100644
index 000000000..44f5a562c
--- /dev/null
+++ b/pkg/state/tests/string_test.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+const nonEmptyString = "hello world"
+
+var allStrings = []string{
+ "",
+ nonEmptyString,
+ "\\0",
+}
+
+func TestString(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allStrings))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allStrings)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allStrings)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allStrings))))
+}
diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go
new file mode 100644
index 000000000..bd2c2b399
--- /dev/null
+++ b/pkg/state/tests/struct.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+type unregisteredEmptyStruct struct{}
+
+// typeOnlyEmptyStruct just implements the state.Type interface.
+type typeOnlyEmptyStruct struct{}
+
+func (*typeOnlyEmptyStruct) StateTypeName() string { return "registeredEmptyStruct" }
+
+func (*typeOnlyEmptyStruct) StateFields() []string { return nil }
+
+// +stateify savable
+type savableEmptyStruct struct{}
+
+// +stateify savable
+type emptyStructPointer struct {
+ nothing *struct{}
+}
+
+// +stateify savable
+type outerSame struct {
+ inner inner
+}
+
+// +stateify savable
+type outerFieldFirst struct {
+ inner inner
+ v int64
+}
+
+// +stateify savable
+type outerFieldSecond struct {
+ v int64
+ inner inner
+}
+
+// +stateify savable
+type outerArray struct {
+ inner [2]inner
+}
+
+// +stateify savable
+type inner struct {
+ v int64
+}
+
+// +stateify savable
+type system struct {
+ v1 interface{}
+ v2 interface{}
+}
diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go
new file mode 100644
index 000000000..de9d17aa7
--- /dev/null
+++ b/pkg/state/tests/struct_test.go
@@ -0,0 +1,89 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func TestEmptyStruct(t *testing.T) {
+ runTestCases(t, false, "plain", []interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ savableEmptyStruct{},
+ })
+ runTestCases(t, false, "pointers", pointersTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ savableEmptyStruct{},
+ }))
+ runTestCases(t, false, "interfaces-pass", interfacesTo([]interface{}{
+ // Only registered types can be dispatched via interfaces. All
+ // other types should fail, even if it is the empty struct.
+ savableEmptyStruct{},
+ }))
+ runTestCases(t, true, "interfaces-fail", interfacesTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ }))
+ runTestCases(t, false, "interfacesToPointers-pass", interfacesTo(pointersTo([]interface{}{
+ savableEmptyStruct{},
+ })))
+ runTestCases(t, true, "interfacesToPointers-fail", interfacesTo(pointersTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ })))
+
+ // Ensuring empty struct aliasing works.
+ es := emptyStructPointer{new(struct{})}
+ runTestCases(t, false, "empty-struct-pointers", []interface{}{
+ emptyStructPointer{},
+ es,
+ []emptyStructPointer{es, es}, // Same pointer.
+ })
+}
+
+func TestRegisterTypeOnlyStruct(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Register did not panic")
+ }
+ }()
+ state.Register((*typeOnlyEmptyStruct)(nil))
+}
+
+func TestEmbeddedPointers(t *testing.T) {
+ var (
+ ofs outerSame
+ of1 outerFieldFirst
+ of2 outerFieldSecond
+ oa outerArray
+ )
+
+ runTestCases(t, false, "embedded-pointers", []interface{}{
+ system{&ofs, &ofs.inner},
+ system{&ofs.inner, &ofs},
+ system{&of1, &of1.inner},
+ system{&of1.inner, &of1},
+ system{&of2, &of2.inner},
+ system{&of2.inner, &of2},
+ system{&oa, &oa.inner[0]},
+ system{&oa, &oa.inner[1]},
+ system{&oa.inner[0], &oa},
+ system{&oa.inner[1], &oa},
+ })
+}
diff --git a/pkg/state/tests/tests.go b/pkg/state/tests/tests.go
new file mode 100644
index 000000000..435a0e9db
--- /dev/null
+++ b/pkg/state/tests/tests.go
@@ -0,0 +1,215 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tests tests the state packages.
+package tests
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "math"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/pretty"
+)
+
+// discard is an implementation of wire.Writer.
+type discard struct{}
+
+// Write implements wire.Writer.Write.
+func (discard) Write(p []byte) (int, error) { return len(p), nil }
+
+// WriteByte implements wire.Writer.WriteByte.
+func (discard) WriteByte(byte) error { return nil }
+
+// checkEqual checks if two objects are equal.
+//
+// N.B. This only handles one level of dereferences for NaN. Otherwise we
+// would need to fork the entire implementation of reflect.DeepEqual.
+func checkEqual(root, loadedValue interface{}) bool {
+ if reflect.DeepEqual(root, loadedValue) {
+ return true
+ }
+
+ // NaN is not equal to itself. We handle the case of raw floating point
+ // primitives here, but don't handle this case nested.
+ rf32, ok1 := root.(float32)
+ lf32, ok2 := loadedValue.(float32)
+ if ok1 && ok2 && math.IsNaN(float64(rf32)) && math.IsNaN(float64(lf32)) {
+ return true
+ }
+ rf64, ok1 := root.(float64)
+ lf64, ok2 := loadedValue.(float64)
+ if ok1 && ok2 && math.IsNaN(rf64) && math.IsNaN(lf64) {
+ return true
+ }
+
+ // Same real for complex numbers.
+ rc64, ok1 := root.(complex64)
+ lc64, ok2 := root.(complex64)
+ if ok1 && ok2 {
+ return checkEqual(real(rc64), real(lc64)) && checkEqual(imag(rc64), imag(lc64))
+ }
+ rc128, ok1 := root.(complex128)
+ lc128, ok2 := root.(complex128)
+ if ok1 && ok2 {
+ return checkEqual(real(rc128), real(lc128)) && checkEqual(imag(rc128), imag(lc128))
+ }
+
+ return false
+}
+
+// runTestCases runs a test for each object in objects.
+func runTestCases(t *testing.T, shouldFail bool, prefix string, objects []interface{}) {
+ t.Helper()
+ for i, root := range objects {
+ t.Run(fmt.Sprintf("%s%d", prefix, i), func(t *testing.T) {
+ t.Logf("Original object:\n%#v", root)
+
+ // Save the passed object.
+ saveBuffer := &bytes.Buffer{}
+ saveObjectPtr := reflect.New(reflect.TypeOf(root))
+ saveObjectPtr.Elem().Set(reflect.ValueOf(root))
+ saveStats, err := state.Save(context.Background(), saveBuffer, saveObjectPtr.Interface())
+ if err != nil {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Save failed unexpectedly: %v", err)
+ }
+
+ // Dump the serialized proto to aid with debugging.
+ var ppBuf bytes.Buffer
+ t.Logf("Raw state:\n%v", saveBuffer.Bytes())
+ if err := pretty.PrintText(&ppBuf, bytes.NewReader(saveBuffer.Bytes())); err != nil {
+ // We don't count this as a test failure if we
+ // have shouldFail set, but we will count as a
+ // failure if we were not expecting to fail.
+ if !shouldFail {
+ t.Errorf("PrettyPrint(html=false) failed unexpected: %v", err)
+ }
+ }
+ if err := pretty.PrintHTML(discard{}, bytes.NewReader(saveBuffer.Bytes())); err != nil {
+ // See above.
+ if !shouldFail {
+ t.Errorf("PrettyPrint(html=true) failed unexpected: %v", err)
+ }
+ }
+ t.Logf("Encoded state:\n%s", ppBuf.String())
+ t.Logf("Save stats:\n%s", saveStats.String())
+
+ // Load a new copy of the object.
+ loadObjectPtr := reflect.New(reflect.TypeOf(root))
+ loadStats, err := state.Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface())
+ if err != nil {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Load failed unexpectedly: %v", err)
+ }
+
+ // Compare the values.
+ loadedValue := loadObjectPtr.Elem().Interface()
+ if !checkEqual(root, loadedValue) {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Objects differ:\n\toriginal: %#v\n\tloaded: %#v\n", root, loadedValue)
+ }
+
+ // Everything went okay. Is that good?
+ if shouldFail {
+ t.Fatalf("This test was expected to fail, but didn't.")
+ }
+ t.Logf("Load stats:\n%s", loadStats.String())
+
+ // Truncate half the bytes in the byte stream,
+ // and ensure that we can't restore. Then
+ // truncate only the final byte and ensure that
+ // we can't restore.
+ l := saveBuffer.Len()
+ halfReader := bytes.NewReader(saveBuffer.Bytes()[:l/2])
+ if _, err := state.Load(context.Background(), halfReader, loadObjectPtr.Interface()); err == nil {
+ t.Errorf("Load with half bytes succeeded unexpectedly.")
+ }
+ missingByteReader := bytes.NewReader(saveBuffer.Bytes()[:l-1])
+ if _, err := state.Load(context.Background(), missingByteReader, loadObjectPtr.Interface()); err == nil {
+ t.Errorf("Load with missing byte succeeded unexpectedly.")
+ }
+ })
+ }
+}
+
+// convert converts the slice to an []interface{}.
+func convert(v interface{}) (r []interface{}) {
+ s := reflect.ValueOf(v) // Must be slice.
+ for i := 0; i < s.Len(); i++ {
+ r = append(r, s.Index(i).Interface())
+ }
+ return r
+}
+
+// flatten flattens multiple slices.
+func flatten(vs ...interface{}) (r []interface{}) {
+ for _, v := range vs {
+ r = append(r, convert(v)...)
+ }
+ return r
+}
+
+// filter maps from one slice to another.
+func filter(vs interface{}, fn func(interface{}) (interface{}, bool)) (r []interface{}) {
+ s := reflect.ValueOf(vs)
+ for i := 0; i < s.Len(); i++ {
+ v, ok := fn(s.Index(i).Interface())
+ if ok {
+ r = append(r, v)
+ }
+ }
+ return r
+}
+
+// combine combines objects in two slices as specified.
+func combine(v1, v2 interface{}, fn func(_, _ interface{}) interface{}) (r []interface{}) {
+ s1 := reflect.ValueOf(v1)
+ s2 := reflect.ValueOf(v2)
+ for i := 0; i < s1.Len(); i++ {
+ for j := 0; j < s2.Len(); j++ {
+ // Combine using the given function.
+ r = append(r, fn(s1.Index(i).Interface(), s2.Index(j).Interface()))
+ }
+ }
+ return r
+}
+
+// pointersTo is a filter function that returns pointers.
+func pointersTo(vs interface{}) []interface{} {
+ return filter(vs, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o))
+ v.Elem().Set(reflect.ValueOf(o))
+ return v.Interface(), true
+ })
+}
+
+// interfacesTo is a filter function that returns interface objects.
+func interfacesTo(vs interface{}) []interface{} {
+ return filter(vs, func(o interface{}) (interface{}, bool) {
+ var v [1]interface{}
+ v[0] = o
+ return v, true
+ })
+}
diff --git a/pkg/state/types.go b/pkg/state/types.go
new file mode 100644
index 000000000..215ef80f8
--- /dev/null
+++ b/pkg/state/types.go
@@ -0,0 +1,361 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "sort"
+
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// assertValidType asserts that the type is valid.
+func assertValidType(name string, fields []string) {
+ if name == "" {
+ Failf("type has empty name")
+ }
+ fieldsCopy := make([]string, len(fields))
+ for i := 0; i < len(fields); i++ {
+ if fields[i] == "" {
+ Failf("field has empty name for type %q", name)
+ }
+ fieldsCopy[i] = fields[i]
+ }
+ sort.Slice(fieldsCopy, func(i, j int) bool {
+ return fieldsCopy[i] < fieldsCopy[j]
+ })
+ for i := range fieldsCopy {
+ if i > 0 && fieldsCopy[i-1] == fieldsCopy[i] {
+ Failf("duplicate field %q for type %s", fieldsCopy[i], name)
+ }
+ }
+}
+
+// typeEntry is an entry in the typeDatabase.
+type typeEntry struct {
+ ID typeID
+ wire.Type
+}
+
+// reconciledTypeEntry is a reconciled entry in the typeDatabase.
+type reconciledTypeEntry struct {
+ wire.Type
+ LocalType reflect.Type
+ FieldOrder []int
+}
+
+// typeEncodeDatabase is an internal TypeInfo database for encoding.
+type typeEncodeDatabase struct {
+ // byType maps by type to the typeEntry.
+ byType map[reflect.Type]*typeEntry
+
+ // lastID is the last used ID.
+ lastID typeID
+}
+
+// makeTypeEncodeDatabase makes a typeDatabase.
+func makeTypeEncodeDatabase() typeEncodeDatabase {
+ return typeEncodeDatabase{
+ byType: make(map[reflect.Type]*typeEntry),
+ }
+}
+
+// typeDecodeDatabase is an internal TypeInfo database for decoding.
+type typeDecodeDatabase struct {
+ // byID maps by ID to type.
+ byID []*reconciledTypeEntry
+
+ // pending are entries that are pending validation by Lookup. These
+ // will be reconciled with actual objects. Note that these will also be
+ // used to lookup types by name, since they may not be reconciled and
+ // there's little value to deleting from this map.
+ pending []*wire.Type
+}
+
+// makeTypeDecodeDatabase makes a typeDatabase.
+func makeTypeDecodeDatabase() typeDecodeDatabase {
+ return typeDecodeDatabase{}
+}
+
+// lookupNameFields extracts the name and fields from an object.
+func lookupNameFields(typ reflect.Type) (string, []string, bool) {
+ v := reflect.Zero(reflect.PtrTo(typ)).Interface()
+ t, ok := v.(Type)
+ if !ok {
+ // Is this a primitive?
+ if typ.Kind() == reflect.Interface {
+ return interfaceType, nil, true
+ }
+ name := typ.Name()
+ if _, ok := primitiveTypeDatabase[name]; !ok {
+ // This is not a known type, and not a primitive. The
+ // encoder may proceed for anonymous empty structs, or
+ // it may deference the type pointer and try again.
+ return "", nil, false
+ }
+ return name, nil, true
+ }
+ // Extract the name from the object.
+ name := t.StateTypeName()
+ fields := t.StateFields()
+ assertValidType(name, fields)
+ return name, fields, true
+}
+
+// Lookup looks up or registers the given object.
+//
+// The bool indicates whether this is an existing entry: false means the entry
+// did not exist, and true means the entry did exist. If this bool is false and
+// the returned typeEntry are nil, then the obj did not implement the Type
+// interface.
+func (tdb *typeEncodeDatabase) Lookup(typ reflect.Type) (*typeEntry, bool) {
+ te, ok := tdb.byType[typ]
+ if !ok {
+ // Lookup the type information.
+ name, fields, ok := lookupNameFields(typ)
+ if !ok {
+ // Empty structs may still be encoded, so let the
+ // caller decide what to do from here.
+ return nil, false
+ }
+
+ // Register the new type.
+ tdb.lastID++
+ te = &typeEntry{
+ ID: tdb.lastID,
+ Type: wire.Type{
+ Name: name,
+ Fields: fields,
+ },
+ }
+
+ // All done.
+ tdb.byType[typ] = te
+ return te, false
+ }
+ return te, true
+}
+
+// Register adds a typeID entry.
+func (tbd *typeDecodeDatabase) Register(typ *wire.Type) {
+ assertValidType(typ.Name, typ.Fields)
+ tbd.pending = append(tbd.pending, typ)
+}
+
+// LookupName looks up the type name by ID.
+func (tbd *typeDecodeDatabase) LookupName(id typeID) string {
+ if len(tbd.pending) < int(id) {
+ // This is likely an encoder error?
+ Failf("type ID %d not available", id)
+ }
+ return tbd.pending[id-1].Name
+}
+
+// LookupType looks up the type by ID.
+func (tbd *typeDecodeDatabase) LookupType(id typeID) reflect.Type {
+ name := tbd.LookupName(id)
+ typ, ok := globalTypeDatabase[name]
+ if !ok {
+ // If not available, see if it's primitive.
+ typ, ok = primitiveTypeDatabase[name]
+ if !ok && name == interfaceType {
+ // Matches the built-in interface type.
+ var i interface{}
+ return reflect.TypeOf(&i).Elem()
+ }
+ if !ok {
+ // The type is perhaps not registered?
+ Failf("type name %q is not available", name)
+ }
+ return typ // Primitive type.
+ }
+ return typ // Registered type.
+}
+
+// singleFieldOrder defines the field order for a single field.
+var singleFieldOrder = []int{0}
+
+// Lookup looks up or registers the given object.
+//
+// First, the typeID is searched to see if this has already been appropriately
+// reconciled. If no, then a reconcilation will take place that may result in a
+// field ordering. If a nil reconciledTypeEntry is returned from this method,
+// then the object does not support the Type interface.
+//
+// This method never returns nil.
+func (tbd *typeDecodeDatabase) Lookup(id typeID, typ reflect.Type) *reconciledTypeEntry {
+ if len(tbd.byID) > int(id) && tbd.byID[id-1] != nil {
+ // Already reconciled.
+ return tbd.byID[id-1]
+ }
+ // The ID has not been reconciled yet. That's fine. We need to make
+ // sure it aligns with the current provided object.
+ if len(tbd.pending) < int(id) {
+ // This id was never registered. Probably an encoder error?
+ Failf("typeDatabase does not contain id %d", id)
+ }
+ // Extract the pending info.
+ pending := tbd.pending[id-1]
+ // Grow the byID list.
+ if len(tbd.byID) < int(id) {
+ tbd.byID = append(tbd.byID, make([]*reconciledTypeEntry, int(id)-len(tbd.byID))...)
+ }
+ // Reconcile the type.
+ name, fields, ok := lookupNameFields(typ)
+ if !ok {
+ // Empty structs are decoded only when the type is nil. Since
+ // this isn't the case, we fail here.
+ Failf("unsupported type %q during decode; can't reconcile", pending.Name)
+ }
+ if name != pending.Name {
+ // Are these the same type? Print a helpful message as this may
+ // actually happen in practice if types change.
+ Failf("typeDatabase contains conflicting definitions for id %d: %s->%v (current) and %s->%v (existing)",
+ id, name, fields, pending.Name, pending.Fields)
+ }
+ rte := &reconciledTypeEntry{
+ Type: wire.Type{
+ Name: name,
+ Fields: fields,
+ },
+ LocalType: typ,
+ }
+ // If there are zero or one fields, then we skip allocating the field
+ // slice. There is special handling for decoding in this case. If the
+ // field name does not match, it will be caught in the general purpose
+ // code below.
+ if len(fields) != len(pending.Fields) {
+ Failf("type %q contains different fields: %v (decode) and %v (encode)",
+ name, fields, pending.Fields)
+ }
+ if len(fields) == 0 {
+ tbd.byID[id-1] = rte // Save.
+ return rte
+ }
+ if len(fields) == 1 && fields[0] == pending.Fields[0] {
+ tbd.byID[id-1] = rte // Save.
+ rte.FieldOrder = singleFieldOrder
+ return rte
+ }
+ // For each field in the current object's information, match it to a
+ // field in the destination object. We know from the assertion above
+ // and the insertion on insertion to pending that neither field
+ // contains any duplicates.
+ fieldOrder := make([]int, len(fields))
+ for i, name := range fields {
+ fieldOrder[i] = -1 // Sentinel.
+ // Is it an exact match?
+ if pending.Fields[i] == name {
+ fieldOrder[i] = i
+ continue
+ }
+ // Find the matching field.
+ for j, otherName := range pending.Fields {
+ if name == otherName {
+ fieldOrder[i] = j
+ break
+ }
+ }
+ if fieldOrder[i] == -1 {
+ // The type name matches but we are lacking some common fields.
+ Failf("type %q has mismatched fields: %v (decode) and %v (encode)",
+ name, fields, pending.Fields)
+ }
+ }
+ // The type has been reeconciled.
+ rte.FieldOrder = fieldOrder
+ tbd.byID[id-1] = rte
+ return rte
+}
+
+// interfaceType defines all interfaces.
+const interfaceType = "interface"
+
+// primitiveTypeDatabase is a set of fixed types.
+var primitiveTypeDatabase = func() map[string]reflect.Type {
+ r := make(map[string]reflect.Type)
+ for _, t := range []reflect.Type{
+ reflect.TypeOf(false),
+ reflect.TypeOf(int(0)),
+ reflect.TypeOf(int8(0)),
+ reflect.TypeOf(int16(0)),
+ reflect.TypeOf(int32(0)),
+ reflect.TypeOf(int64(0)),
+ reflect.TypeOf(uint(0)),
+ reflect.TypeOf(uintptr(0)),
+ reflect.TypeOf(uint8(0)),
+ reflect.TypeOf(uint16(0)),
+ reflect.TypeOf(uint32(0)),
+ reflect.TypeOf(uint64(0)),
+ reflect.TypeOf(""),
+ reflect.TypeOf(float32(0.0)),
+ reflect.TypeOf(float64(0.0)),
+ reflect.TypeOf(complex64(0.0)),
+ reflect.TypeOf(complex128(0.0)),
+ } {
+ r[t.Name()] = t
+ }
+ return r
+}()
+
+// globalTypeDatabase is used for dispatching interfaces on decode.
+var globalTypeDatabase = map[string]reflect.Type{}
+
+// Register registers a type.
+//
+// This must be called on init and only done once.
+func Register(t Type) {
+ name := t.StateTypeName()
+ fields := t.StateFields()
+ assertValidType(name, fields)
+ // Register must always be called on pointers.
+ typ := reflect.TypeOf(t)
+ if typ.Kind() != reflect.Ptr {
+ Failf("Register must be called on pointers")
+ }
+ typ = typ.Elem()
+ if typ.Kind() == reflect.Struct {
+ // All registered structs must implement SaverLoader. We allow
+ // the registration is non-struct types with just the Type
+ // interface, but we need to call StateSave/StateLoad methods
+ // on aggregate types.
+ if _, ok := t.(SaverLoader); !ok {
+ Failf("struct %T does not implement SaverLoader", t)
+ }
+ } else {
+ // Non-structs must not have any fields. We don't support
+ // calling StateSave/StateLoad methods on any non-struct types.
+ // If custom behavior is required, these types should be
+ // wrapped in a structure of some kind.
+ if len(fields) != 0 {
+ Failf("non-struct %T has non-zero fields %v", t, fields)
+ }
+ // We don't allow non-structs to implement StateSave/StateLoad
+ // methods, because they won't be called and it's confusing.
+ if _, ok := t.(SaverLoader); ok {
+ Failf("non-struct %T implements SaverLoader", t)
+ }
+ }
+ if _, ok := primitiveTypeDatabase[name]; ok {
+ Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t)
+ }
+ if _, ok := globalTypeDatabase[name]; ok {
+ Failf("conflicting globalTypeDatabase entries for %T: name conflict", t)
+ }
+ if name == interfaceType {
+ Failf("conflicting name for %T: matches interfaceType", t)
+ }
+ globalTypeDatabase[name] = typ
+}
diff --git a/pkg/state/wire/BUILD b/pkg/state/wire/BUILD
new file mode 100644
index 000000000..311b93dcb
--- /dev/null
+++ b/pkg/state/wire/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "wire",
+ srcs = ["wire.go"],
+ marshal = False,
+ stateify = False,
+ visibility = ["//:sandbox"],
+ deps = ["//pkg/gohacks"],
+)
diff --git a/pkg/state/wire/wire.go b/pkg/state/wire/wire.go
new file mode 100644
index 000000000..93dee6740
--- /dev/null
+++ b/pkg/state/wire/wire.go
@@ -0,0 +1,970 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package wire contains a few basic types that can be composed to serialize
+// graph information for the state package. This package defines the wire
+// protocol.
+//
+// Note that these types are careful about how they implement the relevant
+// interfaces (either value receiver or pointer receiver), so that native-sized
+// types, such as integers and simple pointers, can fit inside the interface
+// object.
+//
+// This package also uses panic as control flow, so called should be careful to
+// wrap calls in appropriate handlers.
+//
+// Testing for this package is driven by the state test package.
+package wire
+
+import (
+ "fmt"
+ "io"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
+)
+
+// Reader is the required reader interface.
+type Reader interface {
+ io.Reader
+ ReadByte() (byte, error)
+}
+
+// Writer is the required writer interface.
+type Writer interface {
+ io.Writer
+ WriteByte(byte) error
+}
+
+// readFull is a utility. The equivalent is not needed for Write, but the API
+// contract dictates that it must always complete all bytes given or return an
+// error.
+func readFull(r io.Reader, p []byte) {
+ for done := 0; done < len(p); {
+ n, err := r.Read(p[done:])
+ done += n
+ if n == 0 && err != nil {
+ panic(err)
+ }
+ }
+}
+
+// Object is a generic object.
+type Object interface {
+ // save saves the given object.
+ //
+ // Panic is used for error control flow.
+ save(Writer)
+
+ // load loads a new object of the given type.
+ //
+ // Panic is used for error control flow.
+ load(Reader) Object
+}
+
+// Bool is a boolean.
+type Bool bool
+
+// loadBool loads an object of type Bool.
+func loadBool(r Reader) Bool {
+ b := loadUint(r)
+ return Bool(b == 1)
+}
+
+// save implements Object.save.
+func (b Bool) save(w Writer) {
+ var v Uint
+ if b {
+ v = 1
+ } else {
+ v = 0
+ }
+ v.save(w)
+}
+
+// load implements Object.load.
+func (Bool) load(r Reader) Object { return loadBool(r) }
+
+// Int is a signed integer.
+//
+// This uses varint encoding.
+type Int int64
+
+// loadInt loads an object of type Int.
+func loadInt(r Reader) Int {
+ u := loadUint(r)
+ x := Int(u >> 1)
+ if u&1 != 0 {
+ x = ^x
+ }
+ return x
+}
+
+// save implements Object.save.
+func (i Int) save(w Writer) {
+ u := Uint(i) << 1
+ if i < 0 {
+ u = ^u
+ }
+ u.save(w)
+}
+
+// load implements Object.load.
+func (Int) load(r Reader) Object { return loadInt(r) }
+
+// Uint is an unsigned integer.
+type Uint uint64
+
+// loadUint loads an object of type Uint.
+func loadUint(r Reader) Uint {
+ var (
+ u Uint
+ s uint
+ )
+ for i := 0; i <= 9; i++ {
+ b, err := r.ReadByte()
+ if err != nil {
+ panic(err)
+ }
+ if b < 0x80 {
+ if i == 9 && b > 1 {
+ panic("overflow")
+ }
+ u |= Uint(b) << s
+ return u
+ }
+ u |= Uint(b&0x7f) << s
+ s += 7
+ }
+ panic("unreachable")
+}
+
+// save implements Object.save.
+func (u Uint) save(w Writer) {
+ for u >= 0x80 {
+ if err := w.WriteByte(byte(u) | 0x80); err != nil {
+ panic(err)
+ }
+ u >>= 7
+ }
+ if err := w.WriteByte(byte(u)); err != nil {
+ panic(err)
+ }
+}
+
+// load implements Object.load.
+func (Uint) load(r Reader) Object { return loadUint(r) }
+
+// Float32 is a 32-bit floating point number.
+type Float32 float32
+
+// loadFloat32 loads an object of type Float32.
+func loadFloat32(r Reader) Float32 {
+ n := loadUint(r)
+ return Float32(math.Float32frombits(uint32(n)))
+}
+
+// save implements Object.save.
+func (f Float32) save(w Writer) {
+ n := Uint(math.Float32bits(float32(f)))
+ n.save(w)
+}
+
+// load implements Object.load.
+func (Float32) load(r Reader) Object { return loadFloat32(r) }
+
+// Float64 is a 64-bit floating point number.
+type Float64 float64
+
+// loadFloat64 loads an object of type Float64.
+func loadFloat64(r Reader) Float64 {
+ n := loadUint(r)
+ return Float64(math.Float64frombits(uint64(n)))
+}
+
+// save implements Object.save.
+func (f Float64) save(w Writer) {
+ n := Uint(math.Float64bits(float64(f)))
+ n.save(w)
+}
+
+// load implements Object.load.
+func (Float64) load(r Reader) Object { return loadFloat64(r) }
+
+// Complex64 is a 64-bit complex number.
+type Complex64 complex128
+
+// loadComplex64 loads an object of type Complex64.
+func loadComplex64(r Reader) Complex64 {
+ re := loadFloat32(r)
+ im := loadFloat32(r)
+ return Complex64(complex(float32(re), float32(im)))
+}
+
+// save implements Object.save.
+func (c *Complex64) save(w Writer) {
+ re := Float32(real(*c))
+ im := Float32(imag(*c))
+ re.save(w)
+ im.save(w)
+}
+
+// load implements Object.load.
+func (*Complex64) load(r Reader) Object {
+ c := loadComplex64(r)
+ return &c
+}
+
+// Complex128 is a 128-bit complex number.
+type Complex128 complex128
+
+// loadComplex128 loads an object of type Complex128.
+func loadComplex128(r Reader) Complex128 {
+ re := loadFloat64(r)
+ im := loadFloat64(r)
+ return Complex128(complex(float64(re), float64(im)))
+}
+
+// save implements Object.save.
+func (c *Complex128) save(w Writer) {
+ re := Float64(real(*c))
+ im := Float64(imag(*c))
+ re.save(w)
+ im.save(w)
+}
+
+// load implements Object.load.
+func (*Complex128) load(r Reader) Object {
+ c := loadComplex128(r)
+ return &c
+}
+
+// String is a string.
+type String string
+
+// loadString loads an object of type String.
+func loadString(r Reader) String {
+ l := loadUint(r)
+ p := make([]byte, l)
+ readFull(r, p)
+ return String(gohacks.StringFromImmutableBytes(p))
+}
+
+// save implements Object.save.
+func (s *String) save(w Writer) {
+ l := Uint(len(*s))
+ l.save(w)
+ p := gohacks.ImmutableBytesFromString(string(*s))
+ _, err := w.Write(p) // Must write all bytes.
+ if err != nil {
+ panic(err)
+ }
+}
+
+// load implements Object.load.
+func (*String) load(r Reader) Object {
+ s := loadString(r)
+ return &s
+}
+
+// Dot is a kind of reference: one of Index and FieldName.
+type Dot interface {
+ isDot()
+}
+
+// Index is a reference resolution.
+type Index uint32
+
+func (Index) isDot() {}
+
+// FieldName is a reference resolution.
+type FieldName string
+
+func (*FieldName) isDot() {}
+
+// Ref is a reference to an object.
+type Ref struct {
+ // Root is the root object.
+ Root Uint
+
+ // Dots is the set of traversals required from the Root object above.
+ // Note that this will be stored in reverse order for efficiency.
+ Dots []Dot
+
+ // Type is the base type for the root object. This is non-nil iff Dots
+ // is non-zero length (that is, this is a complex reference). This is
+ // not *strictly* necessary, but can be used to simplify decoding.
+ Type TypeSpec
+}
+
+// loadRef loads an object of type Ref (abstract).
+func loadRef(r Reader) Ref {
+ ref := Ref{
+ Root: loadUint(r),
+ }
+ l := loadUint(r)
+ ref.Dots = make([]Dot, l)
+ for i := 0; i < int(l); i++ {
+ // Disambiguate between an Index (non-negative) and a field
+ // name (negative). This does some space and avoids a dedicate
+ // loadDot function. See Ref.save for the other side.
+ d := loadInt(r)
+ if d >= 0 {
+ ref.Dots[i] = Index(d)
+ continue
+ }
+ p := make([]byte, -d)
+ readFull(r, p)
+ fieldName := FieldName(gohacks.StringFromImmutableBytes(p))
+ ref.Dots[i] = &fieldName
+ }
+ if l != 0 {
+ // Only if dots is non-zero.
+ ref.Type = loadTypeSpec(r)
+ }
+ return ref
+}
+
+// save implements Object.save.
+func (r *Ref) save(w Writer) {
+ r.Root.save(w)
+ l := Uint(len(r.Dots))
+ l.save(w)
+ for _, d := range r.Dots {
+ // See LoadRef. We use non-negative numbers to encode Index
+ // objects and negative numbers to encode field lengths.
+ switch x := d.(type) {
+ case Index:
+ i := Int(x)
+ i.save(w)
+ case *FieldName:
+ d := Int(-len(*x))
+ d.save(w)
+ p := gohacks.ImmutableBytesFromString(string(*x))
+ if _, err := w.Write(p); err != nil {
+ panic(err)
+ }
+ default:
+ panic("unknown dot implementation")
+ }
+ }
+ if l != 0 {
+ // See above.
+ saveTypeSpec(w, r.Type)
+ }
+}
+
+// load implements Object.load.
+func (*Ref) load(r Reader) Object {
+ ref := loadRef(r)
+ return &ref
+}
+
+// Nil is a primitive zero value of any type.
+type Nil struct{}
+
+// loadNil loads an object of type Nil.
+func loadNil(r Reader) Nil {
+ return Nil{}
+}
+
+// save implements Object.save.
+func (Nil) save(w Writer) {}
+
+// load implements Object.load.
+func (Nil) load(r Reader) Object { return loadNil(r) }
+
+// Slice is a slice value.
+type Slice struct {
+ Length Uint
+ Capacity Uint
+ Ref Ref
+}
+
+// loadSlice loads an object of type Slice.
+func loadSlice(r Reader) Slice {
+ return Slice{
+ Length: loadUint(r),
+ Capacity: loadUint(r),
+ Ref: loadRef(r),
+ }
+}
+
+// save implements Object.save.
+func (s *Slice) save(w Writer) {
+ s.Length.save(w)
+ s.Capacity.save(w)
+ s.Ref.save(w)
+}
+
+// load implements Object.load.
+func (*Slice) load(r Reader) Object {
+ s := loadSlice(r)
+ return &s
+}
+
+// Array is an array value.
+type Array struct {
+ Contents []Object
+}
+
+// loadArray loads an object of type Array.
+func loadArray(r Reader) Array {
+ l := loadUint(r)
+ if l == 0 {
+ // Note that there isn't a single object available to encode
+ // the type of, so we need this additional branch.
+ return Array{}
+ }
+ // All the objects here have the same type, so use dynamic dispatch
+ // only once. All other objects will automatically take the same type
+ // as the first object.
+ contents := make([]Object, l)
+ v := Load(r)
+ contents[0] = v
+ for i := 1; i < int(l); i++ {
+ contents[i] = v.load(r)
+ }
+ return Array{
+ Contents: contents,
+ }
+}
+
+// save implements Object.save.
+func (a *Array) save(w Writer) {
+ l := Uint(len(a.Contents))
+ l.save(w)
+ if l == 0 {
+ // See LoadArray.
+ return
+ }
+ // See above.
+ Save(w, a.Contents[0])
+ for i := 1; i < int(l); i++ {
+ a.Contents[i].save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Array) load(r Reader) Object {
+ a := loadArray(r)
+ return &a
+}
+
+// Map is a map value.
+type Map struct {
+ Keys []Object
+ Values []Object
+}
+
+// loadMap loads an object of type Map.
+func loadMap(r Reader) Map {
+ l := loadUint(r)
+ if l == 0 {
+ // See LoadArray.
+ return Map{}
+ }
+ // See type dispatch notes in Array.
+ keys := make([]Object, l)
+ values := make([]Object, l)
+ k := Load(r)
+ v := Load(r)
+ keys[0] = k
+ values[0] = v
+ for i := 1; i < int(l); i++ {
+ keys[i] = k.load(r)
+ values[i] = v.load(r)
+ }
+ return Map{
+ Keys: keys,
+ Values: values,
+ }
+}
+
+// save implements Object.save.
+func (m *Map) save(w Writer) {
+ l := Uint(len(m.Keys))
+ if int(l) != len(m.Values) {
+ panic(fmt.Sprintf("mismatched keys (%d) Aand values (%d)", len(m.Keys), len(m.Values)))
+ }
+ l.save(w)
+ if l == 0 {
+ // See LoadArray.
+ return
+ }
+ // See above.
+ Save(w, m.Keys[0])
+ Save(w, m.Values[0])
+ for i := 1; i < int(l); i++ {
+ m.Keys[i].save(w)
+ m.Values[i].save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Map) load(r Reader) Object {
+ m := loadMap(r)
+ return &m
+}
+
+// TypeSpec is a type dereference.
+type TypeSpec interface {
+ isTypeSpec()
+}
+
+// TypeID is a concrete type ID.
+type TypeID Uint
+
+func (TypeID) isTypeSpec() {}
+
+// TypeSpecPointer is a pointer type.
+type TypeSpecPointer struct {
+ Type TypeSpec
+}
+
+func (*TypeSpecPointer) isTypeSpec() {}
+
+// TypeSpecArray is an array type.
+type TypeSpecArray struct {
+ Count Uint
+ Type TypeSpec
+}
+
+func (*TypeSpecArray) isTypeSpec() {}
+
+// TypeSpecSlice is a slice type.
+type TypeSpecSlice struct {
+ Type TypeSpec
+}
+
+func (*TypeSpecSlice) isTypeSpec() {}
+
+// TypeSpecMap is a map type.
+type TypeSpecMap struct {
+ Key TypeSpec
+ Value TypeSpec
+}
+
+func (*TypeSpecMap) isTypeSpec() {}
+
+// TypeSpecNil is an empty type.
+type TypeSpecNil struct{}
+
+func (TypeSpecNil) isTypeSpec() {}
+
+// TypeSpec types.
+//
+// These use a distinct encoding on the wire, as they are used only in the
+// interface object. They are decoded through the dedicated loadTypeSpec and
+// saveTypeSpec functions.
+const (
+ typeSpecTypeID Uint = iota
+ typeSpecPointer
+ typeSpecArray
+ typeSpecSlice
+ typeSpecMap
+ typeSpecNil
+)
+
+// loadTypeSpec loads TypeSpec values.
+func loadTypeSpec(r Reader) TypeSpec {
+ switch hdr := loadUint(r); hdr {
+ case typeSpecTypeID:
+ return TypeID(loadUint(r))
+ case typeSpecPointer:
+ return &TypeSpecPointer{
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecArray:
+ return &TypeSpecArray{
+ Count: loadUint(r),
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecSlice:
+ return &TypeSpecSlice{
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecMap:
+ return &TypeSpecMap{
+ Key: loadTypeSpec(r),
+ Value: loadTypeSpec(r),
+ }
+ case typeSpecNil:
+ return TypeSpecNil{}
+ default:
+ // This is not a valid stream?
+ panic(fmt.Errorf("unknown header: %d", hdr))
+ }
+}
+
+// saveTypeSpec saves TypeSpec values.
+func saveTypeSpec(w Writer, t TypeSpec) {
+ switch x := t.(type) {
+ case TypeID:
+ typeSpecTypeID.save(w)
+ Uint(x).save(w)
+ case *TypeSpecPointer:
+ typeSpecPointer.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecArray:
+ typeSpecArray.save(w)
+ x.Count.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecSlice:
+ typeSpecSlice.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecMap:
+ typeSpecMap.save(w)
+ saveTypeSpec(w, x.Key)
+ saveTypeSpec(w, x.Value)
+ case TypeSpecNil:
+ typeSpecNil.save(w)
+ default:
+ // This should not happen?
+ panic(fmt.Errorf("unknown type %T", t))
+ }
+}
+
+// Interface is an interface value.
+type Interface struct {
+ Type TypeSpec
+ Value Object
+}
+
+// loadInterface loads an object of type Interface.
+func loadInterface(r Reader) Interface {
+ return Interface{
+ Type: loadTypeSpec(r),
+ Value: Load(r),
+ }
+}
+
+// save implements Object.save.
+func (i *Interface) save(w Writer) {
+ saveTypeSpec(w, i.Type)
+ Save(w, i.Value)
+}
+
+// load implements Object.load.
+func (*Interface) load(r Reader) Object {
+ i := loadInterface(r)
+ return &i
+}
+
+// Type is type information.
+type Type struct {
+ Name string
+ Fields []string
+}
+
+// loadType loads an object of type Type.
+func loadType(r Reader) Type {
+ name := string(loadString(r))
+ l := loadUint(r)
+ fields := make([]string, l)
+ for i := 0; i < int(l); i++ {
+ fields[i] = string(loadString(r))
+ }
+ return Type{
+ Name: name,
+ Fields: fields,
+ }
+}
+
+// save implements Object.save.
+func (t *Type) save(w Writer) {
+ s := String(t.Name)
+ s.save(w)
+ l := Uint(len(t.Fields))
+ l.save(w)
+ for i := 0; i < int(l); i++ {
+ s := String(t.Fields[i])
+ s.save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Type) load(r Reader) Object {
+ t := loadType(r)
+ return &t
+}
+
+// multipleObjects is a special type for serializing multiple objects.
+type multipleObjects []Object
+
+// loadMultipleObjects loads a series of objects.
+func loadMultipleObjects(r Reader) multipleObjects {
+ l := loadUint(r)
+ m := make(multipleObjects, l)
+ for i := 0; i < int(l); i++ {
+ m[i] = Load(r)
+ }
+ return m
+}
+
+// save implements Object.save.
+func (m *multipleObjects) save(w Writer) {
+ l := Uint(len(*m))
+ l.save(w)
+ for i := 0; i < int(l); i++ {
+ Save(w, (*m)[i])
+ }
+}
+
+// load implements Object.load.
+func (*multipleObjects) load(r Reader) Object {
+ m := loadMultipleObjects(r)
+ return &m
+}
+
+// noObjects represents no objects.
+type noObjects struct{}
+
+// loadNoObjects loads a sentinel.
+func loadNoObjects(r Reader) noObjects { return noObjects{} }
+
+// save implements Object.save.
+func (noObjects) save(w Writer) {}
+
+// load implements Object.load.
+func (noObjects) load(r Reader) Object { return loadNoObjects(r) }
+
+// Struct is a basic composite value.
+type Struct struct {
+ TypeID TypeID
+ fields Object // Optionally noObjects or *multipleObjects.
+}
+
+// Field returns a pointer to the given field slot.
+//
+// This must be called after Alloc.
+func (s *Struct) Field(i int) *Object {
+ if fields, ok := s.fields.(*multipleObjects); ok {
+ return &((*fields)[i])
+ }
+ if _, ok := s.fields.(noObjects); ok {
+ // Alloc may be optionally called; can't call twice.
+ panic("Field called inappropriately, wrong Alloc?")
+ }
+ return &s.fields
+}
+
+// Alloc allocates the given number of fields.
+//
+// This must be called before Add and Save.
+//
+// Precondition: slots must be positive.
+func (s *Struct) Alloc(slots int) {
+ switch {
+ case slots == 0:
+ s.fields = noObjects{}
+ case slots == 1:
+ // Leave it alone.
+ case slots > 1:
+ fields := make(multipleObjects, slots)
+ s.fields = &fields
+ default:
+ // Violates precondition.
+ panic(fmt.Sprintf("Alloc called with negative slots %d?", slots))
+ }
+}
+
+// Fields returns the number of fields.
+func (s *Struct) Fields() int {
+ switch x := s.fields.(type) {
+ case *multipleObjects:
+ return len(*x)
+ case noObjects:
+ return 0
+ default:
+ return 1
+ }
+}
+
+// loadStruct loads an object of type Struct.
+func loadStruct(r Reader) Struct {
+ return Struct{
+ TypeID: TypeID(loadUint(r)),
+ fields: Load(r),
+ }
+}
+
+// save implements Object.save.
+//
+// Precondition: Alloc must have been called, and the fields all filled in
+// appropriately. See Alloc and Add for more details.
+func (s *Struct) save(w Writer) {
+ Uint(s.TypeID).save(w)
+ Save(w, s.fields)
+}
+
+// load implements Object.load.
+func (*Struct) load(r Reader) Object {
+ s := loadStruct(r)
+ return &s
+}
+
+// Object types.
+//
+// N.B. Be careful about changing the order or introducing new elements in the
+// middle here. This is part of the wire format and shouldn't change.
+const (
+ typeBool Uint = iota
+ typeInt
+ typeUint
+ typeFloat32
+ typeFloat64
+ typeNil
+ typeRef
+ typeString
+ typeSlice
+ typeArray
+ typeMap
+ typeStruct
+ typeNoObjects
+ typeMultipleObjects
+ typeInterface
+ typeComplex64
+ typeComplex128
+ typeType
+)
+
+// Save saves the given object.
+//
+// +checkescape all
+//
+// N.B. This function will panic on error.
+func Save(w Writer, obj Object) {
+ switch x := obj.(type) {
+ case Bool:
+ typeBool.save(w)
+ x.save(w)
+ case Int:
+ typeInt.save(w)
+ x.save(w)
+ case Uint:
+ typeUint.save(w)
+ x.save(w)
+ case Float32:
+ typeFloat32.save(w)
+ x.save(w)
+ case Float64:
+ typeFloat64.save(w)
+ x.save(w)
+ case Nil:
+ typeNil.save(w)
+ x.save(w)
+ case *Ref:
+ typeRef.save(w)
+ x.save(w)
+ case *String:
+ typeString.save(w)
+ x.save(w)
+ case *Slice:
+ typeSlice.save(w)
+ x.save(w)
+ case *Array:
+ typeArray.save(w)
+ x.save(w)
+ case *Map:
+ typeMap.save(w)
+ x.save(w)
+ case *Struct:
+ typeStruct.save(w)
+ x.save(w)
+ case noObjects:
+ typeNoObjects.save(w)
+ x.save(w)
+ case *multipleObjects:
+ typeMultipleObjects.save(w)
+ x.save(w)
+ case *Interface:
+ typeInterface.save(w)
+ x.save(w)
+ case *Type:
+ typeType.save(w)
+ x.save(w)
+ case *Complex64:
+ typeComplex64.save(w)
+ x.save(w)
+ case *Complex128:
+ typeComplex128.save(w)
+ x.save(w)
+ default:
+ panic(fmt.Errorf("unknown type: %#v", obj))
+ }
+}
+
+// Load loads a new object.
+//
+// +checkescape all
+//
+// N.B. This function will panic on error.
+func Load(r Reader) Object {
+ switch hdr := loadUint(r); hdr {
+ case typeBool:
+ return loadBool(r)
+ case typeInt:
+ return loadInt(r)
+ case typeUint:
+ return loadUint(r)
+ case typeFloat32:
+ return loadFloat32(r)
+ case typeFloat64:
+ return loadFloat64(r)
+ case typeNil:
+ return loadNil(r)
+ case typeRef:
+ return ((*Ref)(nil)).load(r) // Escapes.
+ case typeString:
+ return ((*String)(nil)).load(r) // Escapes.
+ case typeSlice:
+ return ((*Slice)(nil)).load(r) // Escapes.
+ case typeArray:
+ return ((*Array)(nil)).load(r) // Escapes.
+ case typeMap:
+ return ((*Map)(nil)).load(r) // Escapes.
+ case typeStruct:
+ return ((*Struct)(nil)).load(r) // Escapes.
+ case typeNoObjects: // Special for struct.
+ return loadNoObjects(r)
+ case typeMultipleObjects: // Special for struct.
+ return ((*multipleObjects)(nil)).load(r) // Escapes.
+ case typeInterface:
+ return ((*Interface)(nil)).load(r) // Escapes.
+ case typeComplex64:
+ return ((*Complex64)(nil)).load(r) // Escapes.
+ case typeComplex128:
+ return ((*Complex128)(nil)).load(r) // Escapes.
+ case typeType:
+ return ((*Type)(nil)).load(r) // Escapes.
+ default:
+ // This is not a valid stream?
+ panic(fmt.Errorf("unknown header: %d", hdr))
+ }
+}
+
+// LoadUint loads a single unsigned integer.
+//
+// N.B. This function will panic on error.
+func LoadUint(r Reader) uint64 {
+ return uint64(loadUint(r))
+}
+
+// SaveUint saves a single unsigned integer.
+//
+// N.B. This function will panic on error.
+func SaveUint(w Writer, v uint64) {
+ Uint(v).save(w)
+}