summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/compressio/compressio.go54
-rw-r--r--pkg/gohacks/BUILD1
-rw-r--r--pkg/ilist/list.go6
-rw-r--r--pkg/sentry/kernel/BUILD1
-rw-r--r--pkg/sentry/kernel/kernel.go22
-rw-r--r--pkg/sentry/pgalloc/BUILD1
-rw-r--r--pkg/sentry/pgalloc/save_restore.go13
-rw-r--r--pkg/state/BUILD68
-rw-r--r--pkg/state/README.md158
-rw-r--r--pkg/state/decode.go918
-rw-r--r--pkg/state/decode_unsafe.go27
-rw-r--r--pkg/state/encode.go1025
-rw-r--r--pkg/state/encode_unsafe.go48
-rw-r--r--pkg/state/map.go232
-rw-r--r--pkg/state/object.proto140
-rw-r--r--pkg/state/pretty/BUILD13
-rw-r--r--pkg/state/pretty/pretty.go273
-rw-r--r--pkg/state/printer.go251
-rw-r--r--pkg/state/state.go360
-rw-r--r--pkg/state/state_norace.go19
-rw-r--r--pkg/state/state_race.go19
-rw-r--r--pkg/state/state_test.go721
-rw-r--r--pkg/state/statefile/BUILD1
-rw-r--r--pkg/state/statefile/statefile.go15
-rw-r--r--pkg/state/stats.go117
-rw-r--r--pkg/state/tests/BUILD43
-rw-r--r--pkg/state/tests/array.go35
-rw-r--r--pkg/state/tests/array_test.go134
-rw-r--r--pkg/state/tests/bench.go24
-rw-r--r--pkg/state/tests/bench_test.go153
-rw-r--r--pkg/state/tests/bool_test.go31
-rw-r--r--pkg/state/tests/float_test.go118
-rw-r--r--pkg/state/tests/integer.go163
-rw-r--r--pkg/state/tests/integer_test.go94
-rw-r--r--pkg/state/tests/load.go61
-rw-r--r--pkg/state/tests/load_test.go70
-rw-r--r--pkg/state/tests/map.go28
-rw-r--r--pkg/state/tests/map_test.go90
-rw-r--r--pkg/state/tests/register.go21
-rw-r--r--pkg/state/tests/register_test.go167
-rw-r--r--pkg/state/tests/string_test.go34
-rw-r--r--pkg/state/tests/struct.go65
-rw-r--r--pkg/state/tests/struct_test.go89
-rw-r--r--pkg/state/tests/tests.go215
-rw-r--r--pkg/state/types.go361
-rw-r--r--pkg/state/wire/BUILD12
-rw-r--r--pkg/state/wire/wire.go970
-rw-r--r--runsc/cmd/BUILD2
-rw-r--r--runsc/cmd/statefile.go12
-rw-r--r--tools/checkescape/checkescape.go4
-rw-r--r--tools/go_stateify/main.go182
51 files changed, 5171 insertions, 2510 deletions
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
index 5f52cbe74..b094c5662 100644
--- a/pkg/compressio/compressio.go
+++ b/pkg/compressio/compressio.go
@@ -346,20 +346,22 @@ func (p *pool) schedule(c *chunk, callback func(*chunk) error) error {
}
}
-// reader chunks reads and decompresses.
-type reader struct {
+// Reader is a compressed reader.
+type Reader struct {
pool
// in is the source.
in io.Reader
}
+var _ io.Reader = (*Reader)(nil)
+
// NewReader returns a new compressed reader. If key is non-nil, the data stream
// is assumed to contain expected hash values, which will be compared against
// hash values computed from the compressed bytes. See package comments for
// details.
-func NewReader(in io.Reader, key []byte) (io.Reader, error) {
- r := &reader{
+func NewReader(in io.Reader, key []byte) (*Reader, error) {
+ r := &Reader{
in: in,
}
@@ -394,8 +396,19 @@ var errNewBuffer = errors.New("buffer ready")
// ErrHashMismatch is returned if the hash does not match.
var ErrHashMismatch = errors.New("hash mismatch")
+// ReadByte implements wire.Reader.ReadByte.
+func (r *Reader) ReadByte() (byte, error) {
+ var p [1]byte
+ n, err := r.Read(p[:])
+ if n != 1 {
+ return p[0], err
+ }
+ // Suppress EOF.
+ return p[0], nil
+}
+
// Read implements io.Reader.Read.
-func (r *reader) Read(p []byte) (int, error) {
+func (r *Reader) Read(p []byte) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
@@ -551,8 +564,8 @@ func (r *reader) Read(p []byte) (int, error) {
return done, nil
}
-// writer chunks and schedules writes.
-type writer struct {
+// Writer is a compressed writer.
+type Writer struct {
pool
// out is the underlying writer.
@@ -562,6 +575,8 @@ type writer struct {
closed bool
}
+var _ io.Writer = (*Writer)(nil)
+
// NewWriter returns a new compressed writer. If key is non-nil, hash values are
// generated and written out for compressed bytes. See package comments for
// details.
@@ -569,8 +584,8 @@ type writer struct {
// The recommended chunkSize is on the order of 1M. Extra memory may be
// buffered (in the form of read-ahead, or buffered writes), and is limited to
// O(chunkSize * [1+GOMAXPROCS]).
-func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.WriteCloser, error) {
- w := &writer{
+func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, error) {
+ w := &Writer{
pool: pool{
chunkSize: chunkSize,
buf: bufPool.Get().(*bytes.Buffer),
@@ -597,7 +612,7 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.Write
}
// flush writes a single buffer.
-func (w *writer) flush(c *chunk) error {
+func (w *Writer) flush(c *chunk) error {
// Prefix each chunk with a length; this allows the reader to safely
// limit reads while buffering.
l := uint32(c.compressed.Len())
@@ -624,8 +639,23 @@ func (w *writer) flush(c *chunk) error {
return nil
}
+// WriteByte implements wire.Writer.WriteByte.
+//
+// Note that this implementation is necessary on the object itself, as an
+// interface-based dispatch cannot tell whether the array backing the slice
+// escapes, therefore the all bytes written will generate an escape.
+func (w *Writer) WriteByte(b byte) error {
+ var p [1]byte
+ p[0] = b
+ n, err := w.Write(p[:])
+ if n != 1 {
+ return err
+ }
+ return nil
+}
+
// Write implements io.Writer.Write.
-func (w *writer) Write(p []byte) (int, error) {
+func (w *Writer) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
@@ -710,7 +740,7 @@ func (w *writer) Write(p []byte) (int, error) {
}
// Close implements io.Closer.Close.
-func (w *writer) Close() error {
+func (w *Writer) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD
index 798a65eca..35683fe98 100644
--- a/pkg/gohacks/BUILD
+++ b/pkg/gohacks/BUILD
@@ -7,5 +7,6 @@ go_library(
srcs = [
"gohacks_unsafe.go",
],
+ stateify = False,
visibility = ["//:sandbox"],
)
diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go
index 0d07da3b1..f4a4c33d3 100644
--- a/pkg/ilist/list.go
+++ b/pkg/ilist/list.go
@@ -90,7 +90,7 @@ func (l *List) Back() Element {
//
// NOTE: This is an O(n) operation.
func (l *List) Len() (count int) {
- for e := l.Front(); e != nil; e = e.Next() {
+ for e := l.Front(); e != nil; e = (ElementMapper{}.linkerFor(e)).Next() {
count++
}
return count
@@ -182,13 +182,13 @@ func (l *List) Remove(e Element) {
if prev != nil {
ElementMapper{}.linkerFor(prev).SetNext(next)
- } else {
+ } else if l.head == e {
l.head = next
}
if next != nil {
ElementMapper{}.linkerFor(next).SetPrev(prev)
- } else {
+ } else if l.tail == e {
l.tail = prev
}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 1510a7c26..25fe1921b 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -200,6 +200,7 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/state",
"//pkg/state/statefile",
+ "//pkg/state/wire",
"//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 554a42e05..2177b785a 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -34,7 +34,6 @@ package kernel
import (
"errors"
"fmt"
- "io"
"path/filepath"
"sync/atomic"
"time"
@@ -73,6 +72,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -417,7 +417,7 @@ func (k *Kernel) Init(args InitKernelArgs) error {
// SaveTo saves the state of k to w.
//
// Preconditions: The kernel must be paused throughout the call to SaveTo.
-func (k *Kernel) SaveTo(w io.Writer) error {
+func (k *Kernel) SaveTo(w wire.Writer) error {
saveStart := time.Now()
ctx := k.SupervisorContext()
@@ -473,18 +473,18 @@ func (k *Kernel) SaveTo(w io.Writer) error {
//
// N.B. This will also be saved along with the full kernel save below.
cpuidStart := time.Now()
- if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil {
+ if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil {
return err
}
log.Infof("CPUID save took [%s].", time.Since(cpuidStart))
// Save the kernel state.
kernelStart := time.Now()
- var stats state.Stats
- if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil {
+ stats, err := state.Save(k.SupervisorContext(), w, k)
+ if err != nil {
return err
}
- log.Infof("Kernel save stats: %s", &stats)
+ log.Infof("Kernel save stats: %s", stats.String())
log.Infof("Kernel save took [%s].", time.Since(kernelStart))
// Save the memory file's state.
@@ -629,7 +629,7 @@ func (ts *TaskSet) unregisterEpollWaiters() {
}
// LoadFrom returns a new Kernel loaded from args.
-func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
+func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
loadStart := time.Now()
initAppCores := k.applicationCores
@@ -640,7 +640,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// don't need to explicitly install it in the Kernel.
cpuidStart := time.Now()
var features cpuid.FeatureSet
- if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil {
+ if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil {
return err
}
log.Infof("CPUID load took [%s].", time.Since(cpuidStart))
@@ -655,11 +655,11 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// Load the kernel state.
kernelStart := time.Now()
- var stats state.Stats
- if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil {
+ stats, err := state.Load(k.SupervisorContext(), r, k)
+ if err != nil {
return err
}
- log.Infof("Kernel load stats: %s", &stats)
+ log.Infof("Kernel load stats: %s", stats.String())
log.Infof("Kernel load took [%s].", time.Since(kernelStart))
// rootNetworkNamespace should be populated after loading the state file.
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index a9836ba71..e1fcb175f 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -92,6 +92,7 @@ go_library(
"//pkg/sentry/platform",
"//pkg/sentry/usage",
"//pkg/state",
+ "//pkg/state/wire",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go
index f8385c146..78317fa35 100644
--- a/pkg/sentry/pgalloc/save_restore.go
+++ b/pkg/sentry/pgalloc/save_restore.go
@@ -26,11 +26,12 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
"gvisor.dev/gvisor/pkg/usermem"
)
// SaveTo writes f's state to the given stream.
-func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error {
+func (f *MemoryFile) SaveTo(ctx context.Context, w wire.Writer) error {
// Wait for reclaim.
f.mu.Lock()
defer f.mu.Unlock()
@@ -79,10 +80,10 @@ func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error {
}
// Save metadata.
- if err := state.Save(ctx, w, &f.fileSize, nil); err != nil {
+ if _, err := state.Save(ctx, w, &f.fileSize); err != nil {
return err
}
- if err := state.Save(ctx, w, &f.usage, nil); err != nil {
+ if _, err := state.Save(ctx, w, &f.usage); err != nil {
return err
}
@@ -115,9 +116,9 @@ func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error {
}
// LoadFrom loads MemoryFile state from the given stream.
-func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error {
+func (f *MemoryFile) LoadFrom(ctx context.Context, r wire.Reader) error {
// Load metadata.
- if err := state.Load(ctx, r, &f.fileSize, nil); err != nil {
+ if _, err := state.Load(ctx, r, &f.fileSize); err != nil {
return err
}
if err := f.file.Truncate(f.fileSize); err != nil {
@@ -125,7 +126,7 @@ func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error {
}
newMappings := make([]uintptr, f.fileSize>>chunkShift)
f.mappings.Store(newMappings)
- if err := state.Load(ctx, r, &f.usage, nil); err != nil {
+ if _, err := state.Load(ctx, r, &f.usage); err != nil {
return err
}
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index 2b1350135..089b3bbef 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -1,9 +1,47 @@
-load("//tools:defs.bzl", "go_library", "go_test", "proto_library")
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
go_template_instance(
+ name = "pending_list",
+ out = "pending_list.go",
+ package = "state",
+ prefix = "pending",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectEncodeState",
+ "ElementMapper": "pendingMapper",
+ "Linker": "*pendingEntry",
+ },
+)
+
+go_template_instance(
+ name = "deferred_list",
+ out = "deferred_list.go",
+ package = "state",
+ prefix = "deferred",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectEncodeState",
+ "ElementMapper": "deferredMapper",
+ "Linker": "*deferredEntry",
+ },
+)
+
+go_template_instance(
+ name = "complete_list",
+ out = "complete_list.go",
+ package = "state",
+ prefix = "complete",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectDecodeState",
+ "Linker": "*objectDecodeState",
+ },
+)
+
+go_template_instance(
name = "addr_range",
out = "addr_range.go",
package = "state",
@@ -29,7 +67,7 @@ go_template_instance(
types = {
"Key": "uintptr",
"Range": "addrRange",
- "Value": "reflect.Value",
+ "Value": "*objectEncodeState",
"Functions": "addrSetFunctions",
},
)
@@ -39,32 +77,24 @@ go_library(
srcs = [
"addr_range.go",
"addr_set.go",
+ "complete_list.go",
"decode.go",
+ "decode_unsafe.go",
+ "deferred_list.go",
"encode.go",
"encode_unsafe.go",
- "map.go",
- "printer.go",
+ "pending_list.go",
"state.go",
+ "state_norace.go",
+ "state_race.go",
"stats.go",
+ "types.go",
],
marshal = False,
stateify = False,
visibility = ["//:sandbox"],
deps = [
- ":object_go_proto",
- "@com_github_golang_protobuf//proto:go_default_library",
+ "//pkg/log",
+ "//pkg/state/wire",
],
)
-
-proto_library(
- name = "object",
- srcs = ["object.proto"],
- visibility = ["//:sandbox"],
-)
-
-go_test(
- name = "state_test",
- timeout = "long",
- srcs = ["state_test.go"],
- library = ":state",
-)
diff --git a/pkg/state/README.md b/pkg/state/README.md
new file mode 100644
index 000000000..1aa401193
--- /dev/null
+++ b/pkg/state/README.md
@@ -0,0 +1,158 @@
+# State Encoding and Decoding
+
+The state package implements the encoding and decoding of data structures for
+`go_stateify`. This package is designed for use cases other than the standard
+encoding packages, e.g. `gob` and `json`. Principally:
+
+* This package operates on complex object graphs and accurately serializes and
+ restores all relationships. That is, you can have things like: intrusive
+ pointers, cycles, and pointer chains of arbitrary depths. These are not
+ handled appropriately by existing encoders. This is not an implementation
+ flaw: the formats themselves are not capable of representing these graphs,
+ as they can only generate directed trees.
+
+* This package allows installing order-dependent load callbacks and then
+ resolves that graph at load time, with cycle detection. Similarly, there is
+ no analogous feature possible in the standard encoders.
+
+* This package handles the resolution of interfaces, based on a registered
+ type name. For interface objects type information is saved in the serialized
+ format. This is generally true for `gob` as well, but it works differently.
+
+Here's an overview of how encoding and decoding works.
+
+## Encoding
+
+Encoding produces a `statefile`, which contains a list of chunks of the form
+`(header, payload)`. The payload can either be some raw data, or a series of
+encoded wire objects representing some object graph. All encoded objects are
+defined in the `wire` subpackage.
+
+Encoding of an object graph begins with `encodeState.Save`.
+
+### 1. Memory Map & Encoding
+
+To discover relationships between potentially interdependent data structures
+(for example, a struct may contain pointers to members of other data
+structures), the encoder first walks the object graph and constructs a memory
+map of the objects in the input graph. As this walk progresses, objects are
+queued in the `pending` list and items are placed on the `deferred` list as they
+are discovered. No single object will be encoded multiple times, but the
+discovered relationships between objects may change as more parts of the overall
+object graph are discovered.
+
+The encoder starts at the root object and recursively visits all reachable
+objects, recording the address ranges containing the underlying data for each
+object. This is stored as a segment set (`addrSet`), mapping address ranges to
+the of the object occupying the range; see `encodeState.values`. Note that there
+is special handling for zero-sized types and map objects during this process.
+
+Additionally, the encoder assigns each object a unique identifier which is used
+to indicate relationships between objects in the statefile; see `objectID` in
+`encode.go`.
+
+### 2. Type Serialization
+
+The enoder will subsequently serialize all information about discovered types,
+including field names. These are used during decoding to reconcile these types
+with other internally registered types.
+
+### 3. Object Serialization
+
+With a full address map, and all objects correctly encoded, all object encodings
+are serialized. The assigned `objectID`s aren't explicitly encoded in the
+statefile. The order of object messages in the stream determine their IDs.
+
+### Example
+
+Given the following data structure definitions:
+
+```go
+type system struct {
+ o *outer
+ i *inner
+}
+
+type outer struct {
+ a int64
+ cn *container
+}
+
+type container struct {
+ n uint64
+ elem *inner
+}
+
+type inner struct {
+ c container
+ x, y uint64
+}
+```
+
+Initialized like this:
+
+```go
+o := outer{
+ a: 10,
+ cn: nil,
+}
+i := inner{
+ x: 20,
+ y: 30,
+ c: container{},
+}
+s := system{
+ o: &o,
+ i: &i,
+}
+
+o.cn = &i.c
+o.cn.elem = &i
+
+```
+
+Encoding will produce an object stream like this:
+
+```
+g0r1 = struct{
+ i: g0r3,
+ o: g0r2,
+}
+g0r2 = struct{
+ a: 10,
+ cn: g0r3.c,
+}
+g0r3 = struct{
+ c: struct{
+ elem: g0r3,
+ n: 0u,
+ },
+ x: 20u,
+ y: 30u,
+}
+```
+
+Note how `g0r3.c` is correctly encoded as the underlying `container` object for
+`inner.c`, and how the pointer from `outer.cn` points to it, despite `system.i`
+being discovered after the pointer to it in `system.o.cn`. Also note that
+decoding isn't strictly reliant on the order of encoded object stream, as long
+as the relationship between objects are correctly encoded.
+
+## Decoding
+
+Decoding reads the statefile and reconstructs the object graph. Decoding begins
+in `decodeState.Load`. Decoding is performed in a single pass over the object
+stream in the statefile, and a subsequent pass over all deserialized objects is
+done to fire off all loading callbacks in the correctly defined order. Note that
+introducing cycles is possible here, but these are detected and an error will be
+returned.
+
+Decoding is relatively straight forward. For most primitive values, the decoder
+constructs an appropriate object and fills it with the values encoded in the
+statefile. Pointers need special handling, as they must point to a value
+allocated elsewhere. When values are constructed, the decoder indexes them by
+their `objectID`s in `decodeState.objectsByID`. The target of pointers are
+resolved by searching for the target in this index by their `objectID`; see
+`decodeState.register`. For pointers to values inside another value (fields in a
+pointer, elements of an array), the decoder uses the accessor path to walk to
+the appropriate location; see `walkChild`.
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
index 590c241a3..c9971cdf6 100644
--- a/pkg/state/decode.go
+++ b/pkg/state/decode.go
@@ -17,28 +17,49 @@ package state
import (
"bytes"
"context"
- "encoding/binary"
- "errors"
"fmt"
- "io"
+ "math"
"reflect"
- "sort"
- "github.com/golang/protobuf/proto"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
-// objectState represents an object that may be in the process of being
+// internalCallback is a interface called on object completion.
+//
+// There are two implementations: objectDecodeState & userCallback.
+type internalCallback interface {
+ // source returns the dependent object. May be nil.
+ source() *objectDecodeState
+
+ // callbackRun executes the callback.
+ callbackRun()
+}
+
+// userCallback is an implementation of internalCallback.
+type userCallback func()
+
+// source implements internalCallback.source.
+func (userCallback) source() *objectDecodeState {
+ return nil
+}
+
+// callbackRun implements internalCallback.callbackRun.
+func (uc userCallback) callbackRun() {
+ uc()
+}
+
+// objectDecodeState represents an object that may be in the process of being
// decoded. Specifically, it represents either a decoded object, or an an
// interest in a future object that will be decoded. When that interest is
// registered (via register), the storage for the object will be created, but
// it will not be decoded until the object is encountered in the stream.
-type objectState struct {
+type objectDecodeState struct {
// id is the id for this object.
- //
- // If this field is zero, then this is an anonymous (unregistered,
- // non-reference primitive) object. This is immutable.
- id uint64
+ id objectID
+
+ // typ is the id for this typeID. This may be zero if this is not a
+ // type-registered structure.
+ typ typeID
// obj is the object. This may or may not be valid yet, depending on
// whether complete returns true. However, regardless of whether the
@@ -57,69 +78,52 @@ type objectState struct {
// blockedBy is the number of dependencies this object has.
blockedBy int
- // blocking is a list of the objects blocked by this one.
- blocking []*objectState
+ // callbacksInline is inline storage for callbacks.
+ callbacksInline [2]internalCallback
// callbacks is a set of callbacks to execute on load.
- callbacks []func()
-
- // path is the decoding path to the object.
- path recoverable
-}
-
-// complete indicates the object is complete.
-func (os *objectState) complete() bool {
- return os.blockedBy == 0 && len(os.callbacks) == 0
-}
-
-// checkComplete checks for completion. If the object is complete, pending
-// callbacks will be executed and checkComplete will be called on downstream
-// objects (those depending on this one).
-func (os *objectState) checkComplete(stats *Stats) {
- if os.blockedBy > 0 {
- return
- }
- stats.Start(os.obj)
+ callbacks []internalCallback
- // Fire all callbacks.
- for _, fn := range os.callbacks {
- fn()
- }
- os.callbacks = nil
-
- // Clear all blocked objects.
- for _, other := range os.blocking {
- other.blockedBy--
- other.checkComplete(stats)
- }
- os.blocking = nil
- stats.Done()
+ completeEntry
}
-// waitFor queues a dependency on the given object.
-func (os *objectState) waitFor(other *objectState, callback func()) {
- os.blockedBy++
- other.blocking = append(other.blocking, os)
- if callback != nil {
- other.callbacks = append(other.callbacks, callback)
+// addCallback adds a callback to the objectDecodeState.
+func (ods *objectDecodeState) addCallback(ic internalCallback) {
+ if ods.callbacks == nil {
+ ods.callbacks = ods.callbacksInline[:0]
}
+ ods.callbacks = append(ods.callbacks, ic)
}
// findCycleFor returns when the given object is found in the blocking set.
-func (os *objectState) findCycleFor(target *objectState) []*objectState {
- for _, other := range os.blocking {
- if other == target {
- return []*objectState{target}
+func (ods *objectDecodeState) findCycleFor(target *objectDecodeState) []*objectDecodeState {
+ for _, ic := range ods.callbacks {
+ other := ic.source()
+ if other != nil && other == target {
+ return []*objectDecodeState{target}
} else if childList := other.findCycleFor(target); childList != nil {
return append(childList, other)
}
}
- return nil
+
+ // This should not occur.
+ Failf("no deadlock found?")
+ panic("unreachable")
}
// findCycle finds a dependency cycle.
-func (os *objectState) findCycle() []*objectState {
- return append(os.findCycleFor(os), os)
+func (ods *objectDecodeState) findCycle() []*objectDecodeState {
+ return append(ods.findCycleFor(ods), ods)
+}
+
+// source implements internalCallback.source.
+func (ods *objectDecodeState) source() *objectDecodeState {
+ return ods
+}
+
+// callbackRun implements internalCallback.callbackRun.
+func (ods *objectDecodeState) callbackRun() {
+ ods.blockedBy--
}
// decodeState is a graph of objects in the process of being decoded.
@@ -137,30 +141,66 @@ type decodeState struct {
// ctx is the decode context.
ctx context.Context
+ // r is the input stream.
+ r wire.Reader
+
+ // types is the type database.
+ types typeDecodeDatabase
+
// objectByID is the set of objects in progress.
- objectsByID map[uint64]*objectState
+ objectsByID []*objectDecodeState
// deferred are objects that have been read, by no interest has been
// registered yet. These will be decoded once interest in registered.
- deferred map[uint64]*pb.Object
+ deferred map[objectID]wire.Object
- // outstanding is the number of outstanding objects.
- outstanding uint32
+ // pending is the set of objects that are not yet complete.
+ pending completeList
- // r is the input stream.
- r io.Reader
-
- // stats is the passed stats object.
- stats *Stats
-
- // recoverable is the panic recover facility.
- recoverable
+ // stats tracks time data.
+ stats Stats
}
// lookup looks up an object in decodeState or returns nil if no such object
// has been previously registered.
-func (ds *decodeState) lookup(id uint64) *objectState {
- return ds.objectsByID[id]
+func (ds *decodeState) lookup(id objectID) *objectDecodeState {
+ if len(ds.objectsByID) < int(id) {
+ return nil
+ }
+ return ds.objectsByID[id-1]
+}
+
+// checkComplete checks for completion.
+func (ds *decodeState) checkComplete(ods *objectDecodeState) bool {
+ // Still blocked?
+ if ods.blockedBy > 0 {
+ return false
+ }
+
+ // Track stats if relevant.
+ if ods.callbacks != nil && ods.typ != 0 {
+ ds.stats.start(ods.typ)
+ defer ds.stats.done()
+ }
+
+ // Fire all callbacks.
+ for _, ic := range ods.callbacks {
+ ic.callbackRun()
+ }
+
+ // Mark completed.
+ cbs := ods.callbacks
+ ods.callbacks = nil
+ ds.pending.Remove(ods)
+
+ // Recursively check others.
+ for _, ic := range cbs {
+ if other := ic.source(); other != nil && other.blockedBy == 0 {
+ ds.checkComplete(other)
+ }
+ }
+
+ return true // All set.
}
// wait registers a dependency on an object.
@@ -168,11 +208,8 @@ func (ds *decodeState) lookup(id uint64) *objectState {
// As a special case, we always allow _useable_ references back to the first
// decoding object because it may have fields that are already decoded. We also
// allow trivial self reference, since they can be handled internally.
-func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
+func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback func()) {
switch id {
- case 0:
- // Nil pointer; nothing to wait for.
- fallthrough
case waiter.id:
// Trivial self reference.
fallthrough
@@ -184,107 +221,188 @@ func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
return
}
+ // Mark as blocked.
+ waiter.blockedBy++
+
// No nil can be returned here.
- waiter.waitFor(ds.lookup(id), callback)
+ other := ds.lookup(id)
+ if callback != nil {
+ // Add the additional user callback.
+ other.addCallback(userCallback(callback))
+ }
+
+ // Mark waiter as unblocked.
+ other.addCallback(waiter)
}
// waitObject notes a blocking relationship.
-func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) {
- if rv, ok := p.Value.(*pb.Object_RefValue); ok {
+func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, callback func()) {
+ if rv, ok := encoded.(*wire.Ref); ok && rv.Root != 0 {
// Refs can encode pointers and maps.
- ds.wait(os, rv.RefValue, callback)
- } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok {
+ ds.wait(ods, objectID(rv.Root), callback)
+ } else if sv, ok := encoded.(*wire.Slice); ok && sv.Ref.Root != 0 {
// See decodeObject; we need to wait for the array (if non-nil).
- ds.wait(os, sv.SliceValue.RefValue, callback)
- } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok {
+ ds.wait(ods, objectID(sv.Ref.Root), callback)
+ } else if iv, ok := encoded.(*wire.Interface); ok {
// It's an interface (wait recurisvely).
- ds.waitObject(os, iv.InterfaceValue.Value, callback)
+ ds.waitObject(ods, iv.Value, callback)
} else if callback != nil {
// Nothing to wait for: execute the callback immediately.
callback()
}
}
+// walkChild returns a child object from obj, given an accessor path. This is
+// the decode-side equivalent to traverse in encode.go.
+//
+// For the purposes of this function, a child object is either a field within a
+// struct or an array element, with one such indirection per element in
+// path. The returned value may be an unexported field, so it may not be
+// directly assignable. See unsafePointerTo.
+func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
+ // See wire.Ref.Dots. The path here is specified in reverse order.
+ for i := len(path) - 1; i >= 0; i-- {
+ switch pc := path[i].(type) {
+ case *wire.FieldName: // Must be a pointer.
+ if obj.Kind() != reflect.Struct {
+ Failf("next component in child path is a field name, but the current object is not a struct. Path: %v, current obj: %#v", path, obj)
+ }
+ obj = obj.FieldByName(string(*pc))
+ case wire.Index: // Embedded.
+ if obj.Kind() != reflect.Array {
+ Failf("next component in child path is an array index, but the current object is not an array. Path: %v, current obj: %#v", path, obj)
+ }
+ obj = obj.Index(int(pc))
+ default:
+ panic("unreachable: switch should be exhaustive")
+ }
+ }
+ return obj
+}
+
// register registers a decode with a type.
//
// This type is only used to instantiate a new object if it has not been
-// registered previously.
-func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState {
- os, ok := ds.objectsByID[id]
- if ok {
- return os
+// registered previously. This depends on the type provided if none is
+// available in the object itself.
+func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value {
+ // Grow the objectsByID slice.
+ id := objectID(r.Root)
+ if len(ds.objectsByID) < int(id) {
+ ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...)
+ }
+
+ // Does this object already exist?
+ ods := ds.objectsByID[id-1]
+ if ods != nil {
+ return walkChild(r.Dots, ods.obj)
+ }
+
+ // Create the object.
+ if len(r.Dots) != 0 {
+ typ = ds.findType(r.Type)
}
+ v := reflect.New(typ)
+ ods = &objectDecodeState{
+ id: id,
+ obj: v.Elem(),
+ }
+ ds.objectsByID[id-1] = ods
+ ds.pending.PushBack(ods)
- // Record in the object index.
- if typ.Kind() == reflect.Map {
- os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()}
- } else {
- os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()}
+ // Process any deferred objects & callbacks.
+ if encoded, ok := ds.deferred[id]; ok {
+ delete(ds.deferred, id)
+ ds.decodeObject(ods, ods.obj, encoded)
}
- ds.objectsByID[id] = os
- if o, ok := ds.deferred[id]; ok {
- // There is a deferred object.
- delete(ds.deferred, id) // Free memory.
- ds.decodeObject(os, os.obj, o, "", nil)
- } else {
- // There is no deferred object.
- ds.outstanding++
+ return walkChild(r.Dots, ods.obj)
+}
+
+// objectDecoder is for decoding structs.
+type objectDecoder struct {
+ // ds is decodeState.
+ ds *decodeState
+
+ // ods is current object being decoded.
+ ods *objectDecodeState
+
+ // reconciledTypeEntry is the reconciled type information.
+ rte *reconciledTypeEntry
+
+ // encoded is the encoded object state.
+ encoded *wire.Struct
+}
+
+// load is helper for the public methods on Source.
+func (od *objectDecoder) load(slot int, objPtr reflect.Value, wait bool, fn func()) {
+ // Note that we have reconciled the type and may remap the fields here
+ // to match what's expected by the decoder. The "slot" parameter here
+ // is in terms of the local type, where the fields in the encoded
+ // object are in terms of the wire object's type, which might be in a
+ // different order (but will have the same fields).
+ v := *od.encoded.Field(od.rte.FieldOrder[slot])
+ od.ds.decodeObject(od.ods, objPtr.Elem(), v)
+ if wait {
+ // Mark this individual object a blocker.
+ od.ds.waitObject(od.ods, v, fn)
}
+}
- return os
+// aterLoad implements Source.AfterLoad.
+func (od *objectDecoder) afterLoad(fn func()) {
+ // Queue the local callback; this will execute when all of the above
+ // data dependencies have been cleared.
+ od.ods.addCallback(userCallback(fn))
}
// decodeStruct decodes a struct value.
-func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) {
- // Set the fields.
- m := Map{newInternalMap(nil, ds, os)}
- defer internalMapPool.Put(m.internalMap)
- for _, field := range s.Fields {
- m.data = append(m.data, entry{
- name: field.Name,
- object: field.Value,
- })
- }
-
- // Sort the fields for efficient searching.
- //
- // Technically, these should already appear in sorted order in the
- // state ordering, so this cost is effectively a single scan to ensure
- // that the order is correct.
- if len(m.data) > 1 {
- sort.Slice(m.data, func(i, j int) bool {
- return m.data[i].name < m.data[j].name
- })
- }
-
- // Invoke the load; this will recursively decode other objects.
- fns, ok := registeredTypes.lookupFns(obj.Addr().Type())
- if ok {
- // Invoke the loader.
- fns.invokeLoad(obj.Addr(), m)
- } else if obj.NumField() == 0 {
- // Allow anonymous empty structs.
- return
- } else {
+func (ds *decodeState) decodeStruct(ods *objectDecodeState, obj reflect.Value, encoded *wire.Struct) {
+ if encoded.TypeID == 0 {
+ // Allow anonymous empty structs, but only if the encoded
+ // object also has no fields.
+ if encoded.Fields() == 0 && obj.NumField() == 0 {
+ return
+ }
+
// Propagate an error.
- panic(fmt.Errorf("unregistered type %s", obj.Type()))
+ Failf("empty struct on wire %#v has field mismatch with type %q", encoded, obj.Type().Name())
+ }
+
+ // Lookup the object type.
+ rte := ds.types.Lookup(typeID(encoded.TypeID), obj.Type())
+ ods.typ = typeID(encoded.TypeID)
+
+ // Invoke the loader.
+ od := objectDecoder{
+ ds: ds,
+ ods: ods,
+ rte: rte,
+ encoded: encoded,
+ }
+ ds.stats.start(ods.typ)
+ defer ds.stats.done()
+ if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
+ // Note: may be a registered empty struct which does not
+ // implement the saver/loader interfaces.
+ sl.StateLoad(Source{internal: od})
}
}
// decodeMap decodes a map value.
-func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) {
+func (ds *decodeState) decodeMap(ods *objectDecodeState, obj reflect.Value, encoded *wire.Map) {
if obj.IsNil() {
+ // See pointerTo.
obj.Set(reflect.MakeMap(obj.Type()))
}
- for i := 0; i < len(m.Keys); i++ {
+ for i := 0; i < len(encoded.Keys); i++ {
// Decode the objects.
kv := reflect.New(obj.Type().Key()).Elem()
vv := reflect.New(obj.Type().Elem()).Elem()
- ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i)
- ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface())
- ds.waitObject(os, m.Keys[i], nil)
- ds.waitObject(os, m.Values[i], nil)
+ ds.decodeObject(ods, kv, encoded.Keys[i])
+ ds.decodeObject(ods, vv, encoded.Values[i])
+ ds.waitObject(ods, encoded.Keys[i], nil)
+ ds.waitObject(ods, encoded.Values[i], nil)
// Set in the map.
obj.SetMapIndex(kv, vv)
@@ -292,271 +410,294 @@ func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map)
}
// decodeArray decodes an array value.
-func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) {
- if len(a.Contents) != obj.Len() {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents)))
+func (ds *decodeState) decodeArray(ods *objectDecodeState, obj reflect.Value, encoded *wire.Array) {
+ if len(encoded.Contents) != obj.Len() {
+ Failf("mismatching array length expect=%d, actual=%d", obj.Len(), len(encoded.Contents))
}
// Decode the contents into the array.
- for i := 0; i < len(a.Contents); i++ {
- ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i)
- ds.waitObject(os, a.Contents[i], nil)
+ for i := 0; i < len(encoded.Contents); i++ {
+ ds.decodeObject(ods, obj.Index(i), encoded.Contents[i])
+ ds.waitObject(ods, encoded.Contents[i], nil)
}
}
-// decodeInterface decodes an interface value.
-func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) {
- // Is this a nil value?
- if i.Type == "" {
- return // Just leave obj alone.
+// findType finds the type for the given wire.TypeSpecs.
+func (ds *decodeState) findType(t wire.TypeSpec) reflect.Type {
+ switch x := t.(type) {
+ case wire.TypeID:
+ typ := ds.types.LookupType(typeID(x))
+ rte := ds.types.Lookup(typeID(x), typ)
+ return rte.LocalType
+ case *wire.TypeSpecPointer:
+ return reflect.PtrTo(ds.findType(x.Type))
+ case *wire.TypeSpecArray:
+ return reflect.ArrayOf(int(x.Count), ds.findType(x.Type))
+ case *wire.TypeSpecSlice:
+ return reflect.SliceOf(ds.findType(x.Type))
+ case *wire.TypeSpecMap:
+ return reflect.MapOf(ds.findType(x.Key), ds.findType(x.Value))
+ default:
+ // Should not happen.
+ Failf("unknown type %#v", t)
}
+ panic("unreachable")
+}
- // Get the dispatchable type. This may not be used if the given
- // reference has already been resolved, but if not we need to know the
- // type to create.
- t, ok := registeredTypes.lookupType(i.Type)
- if !ok {
- panic(fmt.Errorf("no valid type for %q", i.Type))
+// decodeInterface decodes an interface value.
+func (ds *decodeState) decodeInterface(ods *objectDecodeState, obj reflect.Value, encoded *wire.Interface) {
+ if _, ok := encoded.Type.(wire.TypeSpecNil); ok {
+ // Special case; the nil object. Just decode directly, which
+ // will read nil from the wire (if encoded correctly).
+ ds.decodeObject(ods, obj, encoded.Value)
+ return
}
- if obj.Kind() != reflect.Map {
- // Set the obj to be the given typed value; this actually sets
- // obj to be a non-zero value -- namely, it inserts type
- // information. There's no need to do this for maps.
- obj.Set(reflect.Zero(t))
+ // We now need to resolve the actual type.
+ typ := ds.findType(encoded.Type)
+
+ // We need to imbue type information here, then we can proceed to
+ // decode normally. In order to avoid issues with setting value-types,
+ // we create a new non-interface version of this object. We will then
+ // set the interface object to be equal to whatever we decode.
+ origObj := obj
+ obj = reflect.New(typ).Elem()
+ defer origObj.Set(obj)
+
+ // With the object now having sufficient type information to actually
+ // have Set called on it, we can proceed to decode the value.
+ ds.decodeObject(ods, obj, encoded.Value)
+}
+
+// isFloatEq determines if x and y represent the same value.
+func isFloatEq(x float64, y float64) bool {
+ switch {
+ case math.IsNaN(x):
+ return math.IsNaN(y)
+ case math.IsInf(x, 1):
+ return math.IsInf(y, 1)
+ case math.IsInf(x, -1):
+ return math.IsInf(y, -1)
+ default:
+ return x == y
}
+}
- // Decode the dereferenced element; there is no need to wait here, as
- // the interface object shares the current object state.
- ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type)
+// isComplexEq determines if x and y represent the same value.
+func isComplexEq(x complex128, y complex128) bool {
+ return isFloatEq(real(x), real(y)) && isFloatEq(imag(x), imag(y))
}
// decodeObject decodes a object value.
-func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) {
- ds.push(false, format, param)
- ds.stats.Add(obj)
- ds.stats.Start(obj)
-
- switch x := object.GetValue().(type) {
- case *pb.Object_BoolValue:
- obj.SetBool(x.BoolValue)
- case *pb.Object_StringValue:
- obj.SetString(string(x.StringValue))
- case *pb.Object_Int64Value:
- obj.SetInt(x.Int64Value)
- if obj.Int() != x.Int64Value {
- panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type()))
+func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, encoded wire.Object) {
+ switch x := encoded.(type) {
+ case wire.Nil: // Fast path: first.
+ // We leave obj alone here. That's because if obj represents an
+ // interface, it may have been imbued with type information in
+ // decodeInterface, and we don't want to destroy that.
+ case *wire.Ref:
+ // Nil pointers may be encoded in a "forceValue" context. For
+ // those we just leave it alone as the value will already be
+ // correct (nil).
+ if id := objectID(x.Root); id == 0 {
+ return
}
- case *pb.Object_Uint64Value:
- obj.SetUint(x.Uint64Value)
- if obj.Uint() != x.Uint64Value {
- panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type()))
- }
- case *pb.Object_DoubleValue:
- obj.SetFloat(x.DoubleValue)
- if obj.Float() != x.DoubleValue {
- panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type()))
- }
- case *pb.Object_RefValue:
- // Resolve the pointer itself, even though the object may not
- // be decoded yet. You need to use wait() in order to ensure
- // that is the case. See wait above, and Map.Barrier.
- if id := x.RefValue; id != 0 {
- // Decoding the interface should have imparted type
- // information, so from this point it's safe to resolve
- // and use this dynamic information for actually
- // creating the object in register.
- //
- // (For non-interfaces this is a no-op).
- dyntyp := reflect.TypeOf(obj.Interface())
- if dyntyp.Kind() == reflect.Map {
- // Remove the map object count here to avoid
- // double counting, as this object will be
- // counted again when it gets processed later.
- // We do not add a reference count as the
- // reference is artificial.
- ds.stats.Remove(obj)
- obj.Set(ds.register(id, dyntyp).obj)
- } else if dyntyp.Kind() == reflect.Ptr {
- ds.push(true /* dereference */, "", nil)
- obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
- ds.pop()
- } else {
- obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
+
+ // Note that if this is a map type, we go through a level of
+ // indirection to allow for map aliasing.
+ if obj.Kind() == reflect.Map {
+ v := ds.register(x, obj.Type())
+ if v.IsNil() {
+ // Note that we don't want to clobber the map
+ // if has already been decoded by decodeMap. We
+ // just make it so that we have a consistent
+ // reference when that eventually does happen.
+ v.Set(reflect.MakeMap(v.Type()))
}
- } else {
- // We leave obj alone here. That's because if obj
- // represents an interface, it may have been embued
- // with type information in decodeInterface, and we
- // don't want to destroy that information.
+ obj.Set(v)
+ return
}
- case *pb.Object_SliceValue:
- // It's okay to slice the array here, since the contents will
- // still be provided later on. These semantics are a bit
- // strange but they are handled in the Map.Barrier properly.
- //
- // The special semantics of zero ref apply here too.
- if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 {
- v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem())
- obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity)))
+
+ // Normal assignment: authoritative only if no dots.
+ v := ds.register(x, obj.Type().Elem())
+ if v.IsValid() {
+ obj.Set(unsafePointerTo(v))
}
- case *pb.Object_ArrayValue:
- ds.decodeArray(os, obj, x.ArrayValue)
- case *pb.Object_StructValue:
- ds.decodeStruct(os, obj, x.StructValue)
- case *pb.Object_MapValue:
- ds.decodeMap(os, obj, x.MapValue)
- case *pb.Object_InterfaceValue:
- ds.decodeInterface(os, obj, x.InterfaceValue)
- case *pb.Object_ByteArrayValue:
- copyArray(obj, reflect.ValueOf(x.ByteArrayValue))
- case *pb.Object_Uint16ArrayValue:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := x.Uint16ArrayValue.Values
- t := obj.Slice(0, obj.Len()).Interface().([]uint16)
- if len(t) != len(s) {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ case wire.Bool:
+ obj.SetBool(bool(x))
+ case wire.Int:
+ obj.SetInt(int64(x))
+ if obj.Int() != int64(x) {
+ Failf("signed integer truncated from %v to %v", int64(x), obj.Int())
}
- for i := range s {
- t[i] = uint16(s[i])
+ case wire.Uint:
+ obj.SetUint(uint64(x))
+ if obj.Uint() != uint64(x) {
+ Failf("unsigned integer truncated from %v to %v", uint64(x), obj.Uint())
}
- case *pb.Object_Uint32ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values))
- case *pb.Object_Uint64ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values))
- case *pb.Object_UintptrArrayValue:
- copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
- case *pb.Object_Int8ArrayValue:
- copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
- case *pb.Object_Int16ArrayValue:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := x.Int16ArrayValue.Values
- t := obj.Slice(0, obj.Len()).Interface().([]int16)
- if len(t) != len(s) {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ case wire.Float32:
+ obj.SetFloat(float64(x))
+ case wire.Float64:
+ obj.SetFloat(float64(x))
+ if !isFloatEq(obj.Float(), float64(x)) {
+ Failf("floating point number truncated from %v to %v", float64(x), obj.Float())
}
- for i := range s {
- t[i] = int16(s[i])
+ case *wire.Complex64:
+ obj.SetComplex(complex128(*x))
+ case *wire.Complex128:
+ obj.SetComplex(complex128(*x))
+ if !isComplexEq(obj.Complex(), complex128(*x)) {
+ Failf("complex number truncated from %v to %v", complex128(*x), obj.Complex())
}
- case *pb.Object_Int32ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values))
- case *pb.Object_Int64ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values))
- case *pb.Object_BoolArrayValue:
- copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values))
- case *pb.Object_Float64ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values))
- case *pb.Object_Float32ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values))
+ case *wire.String:
+ obj.SetString(string(*x))
+ case *wire.Slice:
+ // See *wire.Ref above; same applies.
+ if id := objectID(x.Ref.Root); id == 0 {
+ return
+ }
+ // Note that it's fine to slice the array here and assume that
+ // contents will still be filled in later on.
+ typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type.
+ v := ds.register(&x.Ref, typ)
+ obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity)))
+ case *wire.Array:
+ ds.decodeArray(ods, obj, x)
+ case *wire.Struct:
+ ds.decodeStruct(ods, obj, x)
+ case *wire.Map:
+ ds.decodeMap(ods, obj, x)
+ case *wire.Interface:
+ ds.decodeInterface(ods, obj, x)
default:
// Shoud not happen, not propagated as an error.
- panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type()))
- }
-
- ds.stats.Done()
- ds.pop()
-}
-
-func copyArray(dest reflect.Value, src reflect.Value) {
- if dest.Len() != src.Len() {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len()))
+ Failf("unknown object %#v for %q", encoded, obj.Type().Name())
}
- reflect.Copy(dest, castSlice(src, dest.Type().Elem()))
}
-// Deserialize deserializes the object state.
+// Load deserializes the object graph rooted at obj.
//
// This function may panic and should be run in safely().
-func (ds *decodeState) Deserialize(obj reflect.Value) {
- ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()}
- ds.outstanding = 1 // The root object.
+func (ds *decodeState) Load(obj reflect.Value) {
+ ds.stats.init()
+ defer ds.stats.fini(func(id typeID) string {
+ return ds.types.LookupName(id)
+ })
+
+ // Create the root object.
+ ds.objectsByID = append(ds.objectsByID, &objectDecodeState{
+ id: 1,
+ obj: obj,
+ })
+
+ // Read the number of objects.
+ lastID, object, err := ReadHeader(ds.r)
+ if err != nil {
+ Failf("header error: %w", err)
+ }
+ if !object {
+ Failf("object missing")
+ }
+
+ // Decode all objects.
+ var (
+ encoded wire.Object
+ ods *objectDecodeState
+ id = objectID(1)
+ tid = typeID(1)
+ )
+ if err := safely(func() {
+ // Decode all objects in the stream.
+ //
+ // Note that the structure of this decoding loop should match
+ // the raw decoding loop in printer.go.
+ for id <= objectID(lastID) {
+ // Unmarshal the object.
+ encoded = wire.Load(ds.r)
+
+ // Is this a type object? Handle inline.
+ if wt, ok := encoded.(*wire.Type); ok {
+ ds.types.Register(wt)
+ tid++
+ encoded = nil
+ continue
+ }
- // Decode all objects in the stream.
- //
- // See above, we never process objects while we have no outstanding
- // interests (other than the very first object).
- for id := uint64(1); ds.outstanding > 0; id++ {
- os := ds.lookup(id)
- ds.stats.Start(os.obj)
-
- o, err := ds.readObject()
- if err != nil {
- panic(err)
- }
+ // Actually resolve the object.
+ ods = ds.lookup(id)
+ if ods != nil {
+ // Decode the object.
+ ds.decodeObject(ods, ods.obj, encoded)
+ } else {
+ // If an object hasn't had interest registered
+ // previously or isn't yet valid, we deferred
+ // decoding until interest is registered.
+ ds.deferred[id] = encoded
+ }
- if os != nil {
- // Decode the object.
- ds.from = &os.path
- ds.decodeObject(os, os.obj, o, "", nil)
- ds.outstanding--
+ // For error handling.
+ ods = nil
+ encoded = nil
+ id++
+ }
+ }); err != nil {
+ // Include as much information as we can, taking into account
+ // the possible state transitions above.
+ if ods != nil {
+ Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err)
+ } else if encoded != nil {
+ Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err)
} else {
- // If an object hasn't had interest registered
- // previously, we deferred decoding until interest is
- // registered.
- ds.deferred[id] = o
+ Failf("general decoding error: %w", err)
}
-
- ds.stats.Done()
- }
-
- // Check the zero-length header at the end.
- length, object, err := ReadHeader(ds.r)
- if err != nil {
- panic(err)
- }
- if length != 0 {
- panic(fmt.Sprintf("expected zero-length terminal, got %d", length))
- }
- if object {
- panic("expected non-object terminal")
}
// Check if we have any deferred objects.
- if count := len(ds.deferred); count > 0 {
- // Shoud not happen, not propagated as an error.
- panic(fmt.Sprintf("still have %d deferred objects", count))
- }
-
- // Scan and fire all callbacks.
- for _, os := range ds.objectsByID {
- os.checkComplete(ds.stats)
+ for id, encoded := range ds.deferred {
+ // Shoud never happen, the graph was bogus.
+ Failf("still have deferred objects: one is ID %d, %#v", id, encoded)
}
- // Check if we have any remaining dependency cycles.
- for _, os := range ds.objectsByID {
- if !os.complete() {
- // This must be the result of a dependency cycle.
- cycle := os.findCycle()
- var buf bytes.Buffer
- buf.WriteString("dependency cycle: {")
- for i, cycleOS := range cycle {
- if i > 0 {
- buf.WriteString(" => ")
+ // Scan and fire all callbacks. We iterate over the list of incomplete
+ // objects until all have been finished. We stop iterating if no
+ // objects become complete (there is a dependency cycle).
+ //
+ // Note that we iterate backwards here, because there will be a strong
+ // tendendcy for blocking relationships to go from earlier objects to
+ // later (deeper) objects in the graph. This will reduce the number of
+ // iterations required to finish all objects.
+ if err := safely(func() {
+ for ds.pending.Back() != nil {
+ thisCycle := false
+ for ods = ds.pending.Back(); ods != nil; {
+ if ds.checkComplete(ods) {
+ thisCycle = true
+ break
}
- buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type()))
+ ods = ods.Prev()
+ }
+ if !thisCycle {
+ break
}
- buf.WriteString("}")
- // Panic as an error; propagate to the caller.
- panic(errors.New(string(buf.Bytes())))
}
- }
-}
-
-type byteReader struct {
- io.Reader
-}
-
-// ReadByte implements io.ByteReader.
-func (br byteReader) ReadByte() (byte, error) {
- var b [1]byte
- n, err := br.Reader.Read(b[:])
- if n > 0 {
- return b[0], nil
- } else if err != nil {
- return 0, err
- } else {
- return 0, io.ErrUnexpectedEOF
+ }); err != nil {
+ Failf("error executing callbacks for %#v: %w", ods.obj.Interface(), err)
+ }
+
+ // Check if we have any remaining dependency cycles. If there are any
+ // objects left in the pending list, then it must be due to a cycle.
+ if ods := ds.pending.Front(); ods != nil {
+ // This must be the result of a dependency cycle.
+ cycle := ods.findCycle()
+ var buf bytes.Buffer
+ buf.WriteString("dependency cycle: {")
+ for i, cycleOS := range cycle {
+ if i > 0 {
+ buf.WriteString(" => ")
+ }
+ fmt.Fprintf(&buf, "%q", cycleOS.obj.Type())
+ }
+ buf.WriteString("}")
+ Failf("incomplete graph: %s", string(buf.Bytes()))
}
}
@@ -565,45 +706,20 @@ func (br byteReader) ReadByte() (byte, error) {
// Each object written to the statefile is prefixed with a header. See
// WriteHeader for more information; these functions are exported to allow
// non-state writes to the file to play nice with debugging tools.
-func ReadHeader(r io.Reader) (length uint64, object bool, err error) {
+func ReadHeader(r wire.Reader) (length uint64, object bool, err error) {
// Read the header.
- length, err = binary.ReadUvarint(byteReader{r})
+ err = safely(func() {
+ length = wire.LoadUint(r)
+ })
if err != nil {
- return
+ // On the header, pass raw I/O errors.
+ if sErr, ok := err.(*ErrState); ok {
+ return 0, false, sErr.Unwrap()
+ }
}
// Decode whether the object is valid.
- object = length&0x1 != 0
- length = length >> 1
+ object = length&objectFlag != 0
+ length &^= objectFlag
return
}
-
-// readObject reads an object from the stream.
-func (ds *decodeState) readObject() (*pb.Object, error) {
- // Read the header.
- length, object, err := ReadHeader(ds.r)
- if err != nil {
- return nil, err
- }
- if !object {
- return nil, fmt.Errorf("invalid object header")
- }
-
- // Read the object.
- buf := make([]byte, length)
- for done := 0; done < len(buf); {
- n, err := ds.r.Read(buf[done:])
- done += n
- if n == 0 && err != nil {
- return nil, err
- }
- }
-
- // Unmarshal.
- obj := new(pb.Object)
- if err := proto.Unmarshal(buf, obj); err != nil {
- return nil, err
- }
-
- return obj, nil
-}
diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go
new file mode 100644
index 000000000..d048f61a1
--- /dev/null
+++ b/pkg/state/decode_unsafe.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on
+// values representing unexported fields. This bypasses visibility, but not
+// type safety.
+func unsafePointerTo(obj reflect.Value) reflect.Value {
+ return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr()))
+}
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
index c5118d3a9..92fcad4e9 100644
--- a/pkg/state/encode.go
+++ b/pkg/state/encode.go
@@ -15,437 +15,797 @@
package state
import (
- "container/list"
"context"
- "encoding/binary"
- "fmt"
- "io"
"reflect"
- "sort"
- "github.com/golang/protobuf/proto"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
-// queuedObject is an object queued for encoding.
-type queuedObject struct {
- id uint64
- obj reflect.Value
- path recoverable
+// objectEncodeState the type and identity of an object occupying a memory
+// address range. This is the value type for addrSet, and the intrusive entry
+// for the pending and deferred lists.
+type objectEncodeState struct {
+ // id is the assigned ID for this object.
+ id objectID
+
+ // obj is the object value. Note that this may be replaced if we
+ // encounter an object that contains this object. When this happens (in
+ // resolve), we will update existing references approprately, below,
+ // and defer a re-encoding of the object.
+ obj reflect.Value
+
+ // encoded is the encoded value of this object. Note that this may not
+ // be up to date if this object is still in the deferred list.
+ encoded wire.Object
+
+ // how indicates whether this object should be encoded as a value. This
+ // is used only for deferred encoding.
+ how encodeStrategy
+
+ // refs are the list of reference objects used by other objects
+ // referring to this object. When the object is updated, these
+ // references may be updated directly and automatically.
+ refs []*wire.Ref
+
+ pendingEntry
+ deferredEntry
}
// encodeState is state used for encoding.
//
-// The encoding process is a breadth-first traversal of the object graph. The
-// inherent races and dependencies are much simpler than the decode case.
+// The encoding process constructs a representation of the in-memory graph of
+// objects before a single object is serialized. This is done to ensure that
+// all references can be fully disambiguated. See resolve for more details.
type encodeState struct {
// ctx is the encode context.
ctx context.Context
- // lastID is the last object ID.
- //
- // See idsByObject for context. Because of the special zero encoding
- // used for reference values, the first ID must be 1.
- lastID uint64
+ // w is the output stream.
+ w wire.Writer
- // idsByObject is a set of objects, indexed via:
- //
- // reflect.ValueOf(x).UnsafeAddr
- //
- // This provides IDs for objects.
- idsByObject map[uintptr]uint64
+ // types is the type database.
+ types typeEncodeDatabase
+
+ // lastID is the last allocated object ID.
+ lastID objectID
- // values stores values that span the addresses.
+ // values tracks the address ranges occupied by objects, along with the
+ // types of these objects. This is used to locate pointer targets,
+ // including pointers to fields within another type.
//
- // addrSet is a a generated type which efficiently stores ranges of
- // addresses. When encoding pointers, these ranges are filled in and
- // used to check for overlapping or conflicting pointers. This would
- // indicate a pointer to an field, or a non-type safe value, neither of
- // which are currently decodable.
+ // Multiple objects may overlap in memory iff the larger object fully
+ // contains the smaller one, and the type of the smaller object matches
+ // a field or array element's type at the appropriate offset. An
+ // arbitrary number of objects may be nested in this manner.
//
- // See the usage of values below for more context.
+ // Note that this does not track zero-sized objects, those are tracked
+ // by zeroValues below.
values addrSet
- // w is the output stream.
- w io.Writer
+ // zeroValues tracks zero-sized objects.
+ zeroValues map[reflect.Type]*objectEncodeState
- // pending is the list of objects to be serialized.
- //
- // This is a set of queuedObjects.
- pending list.List
+ // deferred is the list of objects to be encoded.
+ deferred deferredList
- // done is the a list of finished objects.
- //
- // This is kept to prevent garbage collection and address reuse.
- done list.List
+ // pendingTypes is the list of types to be serialized. Serialization
+ // will occur when all objects have been encoded, but before pending is
+ // serialized.
+ pendingTypes []wire.Type
- // stats is the passed stats object.
- stats *Stats
+ // pending is the list of objects to be serialized. Serialization does
+ // not actually occur until the full object graph is computed.
+ pending pendingList
- // recoverable is the panic recover facility.
- recoverable
+ // stats tracks time data.
+ stats Stats
}
-// register looks up an ID, registering if necessary.
+// isSameSizeParent returns true if child is a field value or element within
+// parent. Only a struct or array can have a child value.
+//
+// isSameSizeParent deals with objects like this:
+//
+// struct child {
+// // fields..
+// }
//
-// If the object was not previously registered, it is enqueued to be serialized.
-// See the documentation for idsByObject for more information.
-func (es *encodeState) register(obj reflect.Value) uint64 {
- // It is not legal to call register for any non-pointer objects (see
- // below), so we panic with a recoverable error if this is a mismatch.
- if obj.Kind() != reflect.Ptr && obj.Kind() != reflect.Map {
- panic(fmt.Errorf("non-pointer %#v registered", obj.Interface()))
+// struct parent {
+// c child
+// }
+//
+// var p parent
+// record(&p.c)
+//
+// Here, &p and &p.c occupy the exact same address range.
+//
+// Or like this:
+//
+// struct child {
+// // fields
+// }
+//
+// var arr [1]parent
+// record(&arr[0])
+//
+// Similarly, &arr[0] and &arr[0].c have the exact same address range.
+//
+// Precondition: parent and child must occupy the same memory.
+func isSameSizeParent(parent reflect.Value, childType reflect.Type) bool {
+ switch parent.Kind() {
+ case reflect.Struct:
+ for i := 0; i < parent.NumField(); i++ {
+ field := parent.Field(i)
+ if field.Type() == childType {
+ return true
+ }
+ // Recurse through any intermediate types.
+ if isSameSizeParent(field, childType) {
+ return true
+ }
+ // Does it make sense to keep going if the first field
+ // doesn't match? Yes, because there might be an
+ // arbitrary number of zero-sized fields before we get
+ // a match, and childType itself can be zero-sized.
+ }
+ return false
+ case reflect.Array:
+ // The only case where an array with more than one elements can
+ // return true is if childType is zero-sized. In such cases,
+ // it's ambiguous which element contains the match since a
+ // zero-sized child object fully fits in any of the zero-sized
+ // elements in an array... However since all elements are of
+ // the same type, we only need to check one element.
+ //
+ // For non-zero-sized childTypes, parent.Len() must be 1, but a
+ // combination of the precondition and an implicit comparison
+ // between the array element size and childType ensures this.
+ return parent.Len() > 0 && isSameSizeParent(parent.Index(0), childType)
+ default:
+ return false
}
+}
- addr := obj.Pointer()
- if obj.Kind() == reflect.Ptr && obj.Elem().Type().Size() == 0 {
- // For zero-sized objects, we always provide a unique ID.
- // That's because the runtime internally multiplexes pointers
- // to the same address. We can't be certain what the intent is
- // with pointers to zero-sized objects, so we just give them
- // all unique identities.
- } else if id, ok := es.idsByObject[addr]; ok {
- // Already registered.
- return id
- }
-
- // Ensure that the first ID given out is one. See note on lastID. The
- // ID zero is used to indicate nil values.
+// nextID returns the next valid ID.
+func (es *encodeState) nextID() objectID {
es.lastID++
- id := es.lastID
- es.idsByObject[addr] = id
- if obj.Kind() == reflect.Ptr {
- // Dereference and treat as a pointer.
- es.pending.PushBack(queuedObject{id: id, obj: obj.Elem(), path: es.recoverable.copy()})
-
- // Register this object at all addresses.
- typ := obj.Elem().Type()
- if size := typ.Size(); size > 0 {
- r := addrRange{addr, addr + size}
- if !es.values.IsEmptyRange(r) {
- old := es.values.LowerBoundSegment(addr).Value().Interface().(recoverable)
- panic(fmt.Errorf("overlapping objects: [new object] %#v [existing object path] %s", obj.Interface(), old.path()))
+ return objectID(es.lastID)
+}
+
+// dummyAddr points to the dummy zero-sized address.
+var dummyAddr = reflect.ValueOf(new(struct{})).Pointer()
+
+// resolve records the address range occupied by an object.
+func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
+ addr := obj.Pointer()
+
+ // Is this a map pointer? Just record the single address. It is not
+ // possible to take any pointers into the map internals.
+ if obj.Kind() == reflect.Map {
+ if addr == 0 {
+ // Just leave the nil reference alone. This is fine, we
+ // may need to encode as a reference in this way. We
+ // return nil for our objectEncodeState so that anyone
+ // depending on this value knows there's nothing there.
+ return
+ }
+ if seg, _ := es.values.Find(addr); seg.Ok() {
+ // Ensure the map types match.
+ existing := seg.Value()
+ if existing.obj.Type() != obj.Type() {
+ Failf("overlapping map objects at 0x%x: [new object] %#v [existing object type] %s", addr, obj, existing.obj)
}
- es.values.Add(r, reflect.ValueOf(es.recoverable.copy()))
+
+ // No sense recording refs, maps may not be replaced by
+ // covering objects, they are maximal.
+ ref.Root = wire.Uint(existing.id)
+ return
}
+
+ // Record the map.
+ oes := &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ how: encodeMapAsValue,
+ }
+ es.values.Add(addrRange{addr, addr + 1}, oes)
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+
+ // See above: no ref recording.
+ ref.Root = wire.Uint(oes.id)
+ return
+ }
+
+ // If not a map, then the object must be a pointer.
+ if obj.Kind() != reflect.Ptr {
+ Failf("attempt to record non-map and non-pointer object %#v", obj)
+ }
+
+ obj = obj.Elem() // Value from here.
+
+ // Is this a zero-sized type?
+ typ := obj.Type()
+ size := typ.Size()
+ if size == 0 {
+ if addr == dummyAddr {
+ // Zero-sized objects point to a dummy byte within the
+ // runtime. There's no sense recording this in the
+ // address map. We add this to the dedicated
+ // zeroValues.
+ //
+ // Note that zero-sized objects must be *true*
+ // zero-sized objects. They cannot be part of some
+ // larger object. In that case, they are assigned a
+ // 1-byte address at the end of the object.
+ oes, ok := es.zeroValues[typ]
+ if !ok {
+ oes = &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ }
+ es.zeroValues[typ] = oes
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+ }
+
+ // There's also no sense tracking back references. We
+ // know that this is a true zero-sized object, and not
+ // part of a larger container, so it will not change.
+ ref.Root = wire.Uint(oes.id)
+ return
+ }
+ size = 1 // See above.
+ }
+
+ // Calculate the container.
+ end := addr + size
+ r := addrRange{addr, end}
+ if seg, _ := es.values.Find(addr); seg.Ok() {
+ existing := seg.Value()
+ switch {
+ case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type():
+ // The object is a perfect match. Happy path. Avoid the
+ // traversal and just return directly. We don't need to
+ // encode the type information or any dots here.
+ ref.Root = wire.Uint(existing.id)
+ existing.refs = append(existing.refs, ref)
+ return
+
+ case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end):
+ // The previously registered object is larger than
+ // this, no need to update. But we expect some
+ // traversal below.
+
+ case seg.Start() == addr && seg.End() == end:
+ if !isSameSizeParent(obj, existing.obj.Type()) {
+ break // Needs traversal.
+ }
+ fallthrough // Needs update.
+
+ case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end):
+ // Update the object and redo the encoding.
+ old := existing.obj
+ existing.obj = obj
+ es.deferred.Remove(existing)
+ es.deferred.PushBack(existing)
+
+ // The previously registered object is superseded by
+ // this new object. We are guaranteed to not have any
+ // mergeable neighbours in this segment set.
+ if !raceEnabled {
+ seg.SetRangeUnchecked(r)
+ } else {
+ // Add extra paranoid. This will be statically
+ // removed at compile time unless a race build.
+ es.values.Remove(seg)
+ es.values.Add(r, existing)
+ seg = es.values.LowerBoundSegment(addr)
+ }
+
+ // Compute the traversal required & update references.
+ dots := traverse(obj.Type(), old.Type(), addr, seg.Start())
+ wt := es.findType(obj.Type())
+ for _, ref := range existing.refs {
+ ref.Dots = append(ref.Dots, dots...)
+ ref.Type = wt
+ }
+ default:
+ // There is a non-sensical overlap.
+ Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj)
+ }
+
+ // Compute the new reference, record and return it.
+ ref.Root = wire.Uint(existing.id)
+ ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr)
+ ref.Type = es.findType(obj.Type())
+ existing.refs = append(existing.refs, ref)
+ return
+ }
+
+ // The only remaining case is a pointer value that doesn't overlap with
+ // any registered addresses. Create a new entry for it, and start
+ // tracking the first reference we just created.
+ oes := &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ }
+ if !raceEnabled {
+ es.values.AddWithoutMerging(r, oes)
} else {
- // Push back the map itself; when maps are encoded from the
- // top-level, forceMap will be equal to true.
- es.pending.PushBack(queuedObject{id: id, obj: obj, path: es.recoverable.copy()})
+ // Merges should never happen. This is just enabled extra
+ // sanity checks because the Merge function below will panic.
+ es.values.Add(r, oes)
+ }
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+ ref.Root = wire.Uint(oes.id)
+ oes.refs = append(oes.refs, ref)
+}
+
+// traverse searches for a target object within a root object, where the target
+// object is a struct field or array element within root, with potentially
+// multiple intervening types. traverse returns the set of field or element
+// traversals required to reach the target.
+//
+// Note that for efficiency, traverse returns the dots in the reverse order.
+// That is, the first traversal required will be the last element of the list.
+//
+// Precondition: The target object must lie completely within the range defined
+// by [rootAddr, rootAddr + sizeof(rootType)].
+func traverse(rootType, targetType reflect.Type, rootAddr, targetAddr uintptr) []wire.Dot {
+ // Recursion base case: the types actually match.
+ if targetType == rootType && targetAddr == rootAddr {
+ return nil
}
- return id
+ switch rootType.Kind() {
+ case reflect.Struct:
+ offset := targetAddr - rootAddr
+ for i := rootType.NumField(); i > 0; i-- {
+ field := rootType.Field(i - 1)
+ // The first field from the end with an offset that is
+ // smaller than or equal to our address offset is where
+ // the target is located. Traverse from there.
+ if field.Offset <= offset {
+ dots := traverse(field.Type, targetType, rootAddr+field.Offset, targetAddr)
+ fieldName := wire.FieldName(field.Name)
+ return append(dots, &fieldName)
+ }
+ }
+ // Should never happen; the target should be reachable.
+ Failf("no field in root type %v contains target type %v", rootType, targetType)
+
+ case reflect.Array:
+ // Since arrays have homogenous types, all elements have the
+ // same size and we can compute where the target lives. This
+ // does not matter for the purpose of typing, but matters for
+ // the purpose of computing the address of the given index.
+ elemSize := int(rootType.Elem().Size())
+ n := int(targetAddr-rootAddr) / elemSize // Relies on integer division rounding down.
+ if rootType.Len() < n {
+ Failf("traversal target of type %v @%x is beyond the end of the array type %v @%x with %v elements",
+ targetType, targetAddr, rootType, rootAddr, rootType.Len())
+ }
+ dots := traverse(rootType.Elem(), targetType, rootAddr+uintptr(n*elemSize), targetAddr)
+ return append(dots, wire.Index(n))
+
+ default:
+ // For any other type, there's no possibility of aliasing so if
+ // the types didn't match earlier then we have an addresss
+ // collision which shouldn't be possible at this point.
+ Failf("traverse failed for root type %v and target type %v", rootType, targetType)
+ }
+ panic("unreachable")
}
// encodeMap encodes a map.
-func (es *encodeState) encodeMap(obj reflect.Value) *pb.Map {
- var (
- keys []*pb.Object
- values []*pb.Object
- )
+func (es *encodeState) encodeMap(obj reflect.Value, dest *wire.Object) {
+ if obj.IsNil() {
+ // Because there is a difference between a nil map and an empty
+ // map, we need to not decode in the case of a truly nil map.
+ *dest = wire.Nil{}
+ return
+ }
+ l := obj.Len()
+ m := &wire.Map{
+ Keys: make([]wire.Object, l),
+ Values: make([]wire.Object, l),
+ }
+ *dest = m
for i, k := range obj.MapKeys() {
v := obj.MapIndex(k)
- kp := es.encodeObject(k, false, ".(key %d)", i)
- vp := es.encodeObject(v, false, "[%#v]", k.Interface())
- keys = append(keys, kp)
- values = append(values, vp)
+ // Map keys must be encoded using the full value because the
+ // type will be omitted after the first key.
+ es.encodeObject(k, encodeAsValue, &m.Keys[i])
+ es.encodeObject(v, encodeAsValue, &m.Values[i])
}
- return &pb.Map{Keys: keys, Values: values}
+}
+
+// objectEncoder is for encoding structs.
+type objectEncoder struct {
+ // es is encodeState.
+ es *encodeState
+
+ // encoded is the encoded struct.
+ encoded *wire.Struct
+}
+
+// save is called by the public methods on Sink.
+func (oe *objectEncoder) save(slot int, obj reflect.Value) {
+ fieldValue := oe.encoded.Field(slot)
+ oe.es.encodeObject(obj, encodeDefault, fieldValue)
}
// encodeStruct encodes a composite object.
-func (es *encodeState) encodeStruct(obj reflect.Value) *pb.Struct {
- // Invoke the save.
- m := Map{newInternalMap(es, nil, nil)}
- defer internalMapPool.Put(m.internalMap)
+func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
+ // Ensure that the obj is addressable. There are two cases when it is
+ // not. First, is when this is dispatched via SaveValue. Second, when
+ // this is a map key as a struct. Either way, we need to make a copy to
+ // obtain an addressable value.
if !obj.CanAddr() {
- // Force it to a * type of the above; this involves a copy.
localObj := reflect.New(obj.Type())
localObj.Elem().Set(obj)
obj = localObj.Elem()
}
- fns, ok := registeredTypes.lookupFns(obj.Addr().Type())
- if ok {
- // Invoke the provided saver.
- fns.invokeSave(obj.Addr(), m)
- } else if obj.NumField() == 0 {
- // Allow unregistered anonymous, empty structs.
- return &pb.Struct{}
- } else {
- // Propagate an error.
- panic(fmt.Errorf("unregistered type %T", obj.Interface()))
- }
-
- // Sort the underlying slice, and check for duplicates. This is done
- // once instead of on each add, because performing this sort once is
- // far more efficient.
- if len(m.data) > 1 {
- sort.Slice(m.data, func(i, j int) bool {
- return m.data[i].name < m.data[j].name
- })
- for i := range m.data {
- if i > 0 && m.data[i-1].name == m.data[i].name {
- panic(fmt.Errorf("duplicate name %s", m.data[i].name))
- }
+
+ // Prepare the value.
+ s := &wire.Struct{}
+ *dest = s
+
+ // Look the type up in the database.
+ te, ok := es.types.Lookup(obj.Type())
+ if te == nil {
+ if obj.NumField() == 0 {
+ // Allow unregistered anonymous, empty structs. This
+ // will just return success without ever invoking the
+ // passed function. This uses the immutable EmptyStruct
+ // variable to prevent an allocation in this case.
+ //
+ // Note that this mechanism does *not* work for
+ // interfaces in general. So you can't dispatch
+ // non-registered empty structs via interfaces because
+ // then they can't be restored.
+ s.Alloc(0)
+ return
}
+ // We need a SaverLoader for struct types.
+ Failf("struct %T does not implement SaverLoader", obj.Interface())
}
-
- // Encode the resulting fields.
- fields := make([]*pb.Field, 0, len(m.data))
- for _, e := range m.data {
- fields = append(fields, &pb.Field{
- Name: e.name,
- Value: e.object,
- })
+ if !ok {
+ // Queue the type to be serialized.
+ es.pendingTypes = append(es.pendingTypes, te.Type)
}
- // Return the encoded object.
- return &pb.Struct{Fields: fields}
+ // Invoke the provided saver.
+ s.TypeID = wire.TypeID(te.ID)
+ s.Alloc(len(te.Fields))
+ oe := objectEncoder{
+ es: es,
+ encoded: s,
+ }
+ es.stats.start(te.ID)
+ defer es.stats.done()
+ if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
+ // Note: may be a registered empty struct which does not
+ // implement the saver/loader interfaces.
+ sl.StateSave(Sink{internal: oe})
+ }
}
// encodeArray encodes an array.
-func (es *encodeState) encodeArray(obj reflect.Value) *pb.Array {
- var (
- contents []*pb.Object
- )
- for i := 0; i < obj.Len(); i++ {
- entry := es.encodeObject(obj.Index(i), false, "[%d]", i)
- contents = append(contents, entry)
- }
- return &pb.Array{Contents: contents}
+func (es *encodeState) encodeArray(obj reflect.Value, dest *wire.Object) {
+ l := obj.Len()
+ a := &wire.Array{
+ Contents: make([]wire.Object, l),
+ }
+ *dest = a
+ for i := 0; i < l; i++ {
+ // We need to encode the full value because arrays are encoded
+ // using the type information from only the first element.
+ es.encodeObject(obj.Index(i), encodeAsValue, &a.Contents[i])
+ }
+}
+
+// findType recursively finds type information.
+func (es *encodeState) findType(typ reflect.Type) wire.TypeSpec {
+ // First: check if this is a proper type. It's possible for pointers,
+ // slices, arrays, maps, etc to all have some different type.
+ te, ok := es.types.Lookup(typ)
+ if te != nil {
+ if !ok {
+ // See encodeStruct.
+ es.pendingTypes = append(es.pendingTypes, te.Type)
+ }
+ return wire.TypeID(te.ID)
+ }
+
+ switch typ.Kind() {
+ case reflect.Ptr:
+ return &wire.TypeSpecPointer{
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Slice:
+ return &wire.TypeSpecSlice{
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Array:
+ return &wire.TypeSpecArray{
+ Count: wire.Uint(typ.Len()),
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Map:
+ return &wire.TypeSpecMap{
+ Key: es.findType(typ.Key()),
+ Value: es.findType(typ.Elem()),
+ }
+ default:
+ // After potentially chasing many pointers, the
+ // ultimate type of the object is not known.
+ Failf("type %q is not known", typ)
+ }
+ panic("unreachable")
}
// encodeInterface encodes an interface.
-//
-// Precondition: the value is not nil.
-func (es *encodeState) encodeInterface(obj reflect.Value) *pb.Interface {
- // Check for the nil interface.
- obj = reflect.ValueOf(obj.Interface())
+func (es *encodeState) encodeInterface(obj reflect.Value, dest *wire.Object) {
+ // Dereference the object.
+ obj = obj.Elem()
if !obj.IsValid() {
- return &pb.Interface{
- Type: "", // left alone in decode.
- Value: &pb.Object{Value: &pb.Object_RefValue{0}},
+ // Special case: the nil object.
+ *dest = &wire.Interface{
+ Type: wire.TypeSpecNil{},
+ Value: wire.Nil{},
}
+ return
}
- // We have an interface value here. How do we save that? We
- // resolve the underlying type and save it as a dispatchable.
- typName, ok := registeredTypes.lookupName(obj.Type())
- if !ok {
- panic(fmt.Errorf("type %s is not registered", obj.Type()))
+
+ // Encode underlying object.
+ i := &wire.Interface{
+ Type: es.findType(obj.Type()),
}
+ *dest = i
+ es.encodeObject(obj, encodeAsValue, &i.Value)
+}
- // Encode the object again.
- return &pb.Interface{
- Type: typName,
- Value: es.encodeObject(obj, false, ".(%s)", typName),
+// isPrimitive returns true if this is a primitive object, or a composite
+// object composed entirely of primitives.
+func isPrimitiveZero(typ reflect.Type) bool {
+ switch typ.Kind() {
+ case reflect.Ptr:
+ // Pointers are always treated as primitive types because we
+ // won't encode directly from here. Returning true here won't
+ // prevent the object from being encoded correctly.
+ return true
+ case reflect.Bool:
+ return true
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return true
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return true
+ case reflect.Float32, reflect.Float64:
+ return true
+ case reflect.Complex64, reflect.Complex128:
+ return true
+ case reflect.String:
+ return true
+ case reflect.Slice:
+ // The slice itself a primitive, but not necessarily the array
+ // that points to. This is similar to a pointer.
+ return true
+ case reflect.Array:
+ // We cannot treat an array as a primitive, because it may be
+ // composed of structures or other things with side-effects.
+ return isPrimitiveZero(typ.Elem())
+ case reflect.Interface:
+ // Since we now that this type is the zero type, the interface
+ // value must be zero. Therefore this is primitive.
+ return true
+ case reflect.Struct:
+ return false
+ case reflect.Map:
+ // The isPrimitiveZero function is called only on zero-types to
+ // see if it's safe to serialize. Since a zero map has no
+ // elements, it is safe to treat as a primitive.
+ return true
+ default:
+ Failf("unknown type %q", typ.Name())
}
+ panic("unreachable")
}
-// encodeObject encodes an object.
-//
-// If mapAsValue is true, then a map will be encoded directly.
-func (es *encodeState) encodeObject(obj reflect.Value, mapAsValue bool, format string, param interface{}) (object *pb.Object) {
- es.push(false, format, param)
- es.stats.Add(obj)
- es.stats.Start(obj)
+// encodeStrategy is the strategy used for encodeObject.
+type encodeStrategy int
+const (
+ // encodeDefault means types are encoded normally as references.
+ encodeDefault encodeStrategy = iota
+
+ // encodeAsValue means that types will never take short-circuited and
+ // will always be encoded as a normal value.
+ encodeAsValue
+
+ // encodeMapAsValue means that even maps will be fully encoded.
+ encodeMapAsValue
+)
+
+// encodeObject encodes an object.
+func (es *encodeState) encodeObject(obj reflect.Value, how encodeStrategy, dest *wire.Object) {
+ if how == encodeDefault && isPrimitiveZero(obj.Type()) && obj.IsZero() {
+ *dest = wire.Nil{}
+ return
+ }
switch obj.Kind() {
+ case reflect.Ptr: // Fast path: first.
+ r := new(wire.Ref)
+ *dest = r
+ if obj.IsNil() {
+ // May be in an array or elsewhere such that a value is
+ // required. So we encode as a reference to the zero
+ // object, which does not exist. Note that this has to
+ // be handled correctly in the decode path as well.
+ return
+ }
+ es.resolve(obj, r)
case reflect.Bool:
- object = &pb.Object{Value: &pb.Object_BoolValue{obj.Bool()}}
+ *dest = wire.Bool(obj.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- object = &pb.Object{Value: &pb.Object_Int64Value{obj.Int()}}
+ *dest = wire.Int(obj.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- object = &pb.Object{Value: &pb.Object_Uint64Value{obj.Uint()}}
- case reflect.Float32, reflect.Float64:
- object = &pb.Object{Value: &pb.Object_DoubleValue{obj.Float()}}
+ *dest = wire.Uint(obj.Uint())
+ case reflect.Float32:
+ *dest = wire.Float32(obj.Float())
+ case reflect.Float64:
+ *dest = wire.Float64(obj.Float())
+ case reflect.Complex64:
+ c := wire.Complex64(obj.Complex())
+ *dest = &c // Needs alloc.
+ case reflect.Complex128:
+ c := wire.Complex128(obj.Complex())
+ *dest = &c // Needs alloc.
+ case reflect.String:
+ s := wire.String(obj.String())
+ *dest = &s // Needs alloc.
case reflect.Array:
- switch obj.Type().Elem().Kind() {
- case reflect.Uint8:
- object = &pb.Object{Value: &pb.Object_ByteArrayValue{pbSlice(obj).Interface().([]byte)}}
- case reflect.Uint16:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := pbSlice(obj).Interface().([]uint16)
- t := make([]uint32, len(s))
- for i := range s {
- t[i] = uint32(s[i])
- }
- object = &pb.Object{Value: &pb.Object_Uint16ArrayValue{&pb.Uint16S{Values: t}}}
- case reflect.Uint32:
- object = &pb.Object{Value: &pb.Object_Uint32ArrayValue{&pb.Uint32S{Values: pbSlice(obj).Interface().([]uint32)}}}
- case reflect.Uint64:
- object = &pb.Object{Value: &pb.Object_Uint64ArrayValue{&pb.Uint64S{Values: pbSlice(obj).Interface().([]uint64)}}}
- case reflect.Uintptr:
- object = &pb.Object{Value: &pb.Object_UintptrArrayValue{&pb.Uintptrs{Values: pbSlice(obj).Interface().([]uint64)}}}
- case reflect.Int8:
- object = &pb.Object{Value: &pb.Object_Int8ArrayValue{&pb.Int8S{Values: pbSlice(obj).Interface().([]byte)}}}
- case reflect.Int16:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := pbSlice(obj).Interface().([]int16)
- t := make([]int32, len(s))
- for i := range s {
- t[i] = int32(s[i])
- }
- object = &pb.Object{Value: &pb.Object_Int16ArrayValue{&pb.Int16S{Values: t}}}
- case reflect.Int32:
- object = &pb.Object{Value: &pb.Object_Int32ArrayValue{&pb.Int32S{Values: pbSlice(obj).Interface().([]int32)}}}
- case reflect.Int64:
- object = &pb.Object{Value: &pb.Object_Int64ArrayValue{&pb.Int64S{Values: pbSlice(obj).Interface().([]int64)}}}
- case reflect.Bool:
- object = &pb.Object{Value: &pb.Object_BoolArrayValue{&pb.Bools{Values: pbSlice(obj).Interface().([]bool)}}}
- case reflect.Float32:
- object = &pb.Object{Value: &pb.Object_Float32ArrayValue{&pb.Float32S{Values: pbSlice(obj).Interface().([]float32)}}}
- case reflect.Float64:
- object = &pb.Object{Value: &pb.Object_Float64ArrayValue{&pb.Float64S{Values: pbSlice(obj).Interface().([]float64)}}}
- default:
- object = &pb.Object{Value: &pb.Object_ArrayValue{es.encodeArray(obj)}}
- }
+ es.encodeArray(obj, dest)
case reflect.Slice:
- if obj.IsNil() || obj.Cap() == 0 {
- // Handled specially in decode; store as nil value.
- object = &pb.Object{Value: &pb.Object_RefValue{0}}
- } else {
- // Serialize a slice as the array plus length and capacity.
- object = &pb.Object{Value: &pb.Object_SliceValue{&pb.Slice{
- Capacity: uint32(obj.Cap()),
- Length: uint32(obj.Len()),
- RefValue: es.register(arrayFromSlice(obj)),
- }}}
+ s := &wire.Slice{
+ Capacity: wire.Uint(obj.Cap()),
+ Length: wire.Uint(obj.Len()),
}
- case reflect.String:
- object = &pb.Object{Value: &pb.Object_StringValue{[]byte(obj.String())}}
- case reflect.Ptr:
+ *dest = s
+ // Note that we do need to provide a wire.Slice type here as
+ // how is not encodeDefault. If this were the case, then it
+ // would have been caught by the IsZero check above and we
+ // would have just used wire.Nil{}.
if obj.IsNil() {
- // Handled specially in decode; store as a nil value.
- object = &pb.Object{Value: &pb.Object_RefValue{0}}
- } else {
- es.push(true /* dereference */, "", nil)
- object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
- es.pop()
+ return
}
+ // Slices need pointer resolution.
+ es.resolve(arrayFromSlice(obj), &s.Ref)
case reflect.Interface:
- // We don't check for IsNil here, as we want to encode type
- // information. The case of the empty interface (no type, no
- // value) is handled by encodeInteface.
- object = &pb.Object{Value: &pb.Object_InterfaceValue{es.encodeInterface(obj)}}
+ es.encodeInterface(obj, dest)
case reflect.Struct:
- object = &pb.Object{Value: &pb.Object_StructValue{es.encodeStruct(obj)}}
+ es.encodeStruct(obj, dest)
case reflect.Map:
- if obj.IsNil() {
- // Handled specially in decode; store as a nil value.
- object = &pb.Object{Value: &pb.Object_RefValue{0}}
- } else if mapAsValue {
- // Encode the map directly.
- object = &pb.Object{Value: &pb.Object_MapValue{es.encodeMap(obj)}}
- } else {
- // Encode a reference to the map.
- //
- // Remove the map object count here to avoid double
- // counting, as this object will be counted again when
- // it gets processed later. We do not add a reference
- // count as the reference is artificial.
- es.stats.Remove(obj)
- object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
+ if how == encodeMapAsValue {
+ es.encodeMap(obj, dest)
+ return
}
+ r := new(wire.Ref)
+ *dest = r
+ es.resolve(obj, r)
default:
- panic(fmt.Errorf("unknown primitive %#v", obj.Interface()))
+ Failf("unknown object %#v", obj.Interface())
+ panic("unreachable")
}
-
- es.stats.Done()
- es.pop()
- return
}
-// Serialize serializes the object state.
-//
-// This function may panic and should be run in safely().
-func (es *encodeState) Serialize(obj reflect.Value) {
- es.register(obj.Addr())
-
- // Pop off the list until we're done.
- for es.pending.Len() > 0 {
- e := es.pending.Front()
-
- // Extract the queued object.
- qo := e.Value.(queuedObject)
- es.stats.Start(qo.obj)
+// Save serializes the object graph rooted at obj.
+func (es *encodeState) Save(obj reflect.Value) {
+ es.stats.init()
+ defer es.stats.fini(func(id typeID) string {
+ return es.pendingTypes[id-1].Name
+ })
+
+ // Resolve the first object, which should queue a pile of additional
+ // objects on the pending list. All queued objects should be fully
+ // resolved, and we should be able to serialize after this call.
+ var root wire.Ref
+ es.resolve(obj.Addr(), &root)
+
+ // Encode the graph.
+ var oes *objectEncodeState
+ if err := safely(func() {
+ for oes = es.deferred.Front(); oes != nil; oes = es.deferred.Front() {
+ // Remove and encode the object. Note that as a result
+ // of this encoding, the object may be enqueued on the
+ // deferred list yet again. That's expected, and why it
+ // is removed first.
+ es.deferred.Remove(oes)
+ es.encodeObject(oes.obj, oes.how, &oes.encoded)
+ }
+ }); err != nil {
+ // Include the object in the error message.
+ Failf("encoding error at object %#v: %w", oes.obj.Interface(), err)
+ }
- es.pending.Remove(e)
+ // Check that items are pending.
+ if es.pending.Front() == nil {
+ Failf("pending is empty?")
+ }
- es.from = &qo.path
- o := es.encodeObject(qo.obj, true, "", nil)
+ // Write the header with the number of objects. Note that there is no
+ // way that es.lastID could conflict with objectID, which would
+ // indicate that an impossibly large encoding.
+ if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil {
+ Failf("error writing header: %w", err)
+ }
- // Emit to our output stream.
- if err := es.writeObject(qo.id, o); err != nil {
- panic(err)
+ // Serialize all pending types and pending objects. Note that we don't
+ // bother removing from this list as we walk it because that just
+ // wastes time. It will not change after this point.
+ var id objectID
+ if err := safely(func() {
+ for _, wt := range es.pendingTypes {
+ // Encode the type.
+ wire.Save(es.w, &wt)
}
+ for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() {
+ id++ // First object is 1.
+ if oes.id != id {
+ Failf("expected id %d, got %d", id, oes.id)
+ }
- // Mark as done.
- es.done.PushBack(e)
- es.stats.Done()
+ // Marshall the object.
+ wire.Save(es.w, oes.encoded)
+ }
+ }); err != nil {
+ // Include the object and the error.
+ Failf("error serializing object %#v: %w", oes.encoded, err)
}
- // Write a zero-length terminal at the end; this is a sanity check
- // applied at decode time as well (see decode.go).
- if err := WriteHeader(es.w, 0, false); err != nil {
- panic(err)
+ // Check what we wrote.
+ if id != es.lastID {
+ Failf("expected %d objects, wrote %d", es.lastID, id)
}
}
+// objectFlag indicates that the length is a # of objects, rather than a raw
+// byte length. When this is set on a length header in the stream, it may be
+// decoded appropriately.
+const objectFlag uint64 = 1 << 63
+
// WriteHeader writes a header.
//
// Each object written to the statefile should be prefixed with a header. In
// order to generate statefiles that play nicely with debugging tools, raw
// writes should be prefixed with a header with object set to false and the
// appropriate length. This will allow tools to skip these regions.
-func WriteHeader(w io.Writer, length uint64, object bool) error {
- // The lowest-order bit encodes whether this is a valid object. This is
- // a purely internal convention, but allows the object flag to be
- // returned from ReadHeader.
- length = length << 1
+func WriteHeader(w wire.Writer, length uint64, object bool) error {
+ // Sanity check the length.
+ if length&objectFlag != 0 {
+ Failf("impossibly huge length: %d", length)
+ }
if object {
- length |= 0x1
+ length |= objectFlag
}
// Write a header.
- var hdr [32]byte
- encodedLen := binary.PutUvarint(hdr[:], length)
- for done := 0; done < encodedLen; {
- n, err := w.Write(hdr[done:encodedLen])
- done += n
- if n == 0 && err != nil {
- return err
- }
- }
-
- return nil
+ return safely(func() {
+ wire.SaveUint(w, length)
+ })
}
-// writeObject writes an object to the stream.
-func (es *encodeState) writeObject(id uint64, obj *pb.Object) error {
- // Marshal the proto.
- buf, err := proto.Marshal(obj)
- if err != nil {
- return err
- }
+// pendingMapper is for the pending list.
+type pendingMapper struct{}
- // Write the object header.
- if err := WriteHeader(es.w, uint64(len(buf)), true); err != nil {
- return err
- }
+func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry }
- // Write the object.
- for done := 0; done < len(buf); {
- n, err := es.w.Write(buf[done:])
- done += n
- if n == 0 && err != nil {
- return err
- }
- }
+// deferredMapper is for the deferred list.
+type deferredMapper struct{}
- return nil
-}
+func (deferredMapper) linkerFor(oes *objectEncodeState) *deferredEntry { return &oes.deferredEntry }
// addrSetFunctions is used by addrSet.
type addrSetFunctions struct{}
@@ -458,13 +818,24 @@ func (addrSetFunctions) MaxKey() uintptr {
return ^uintptr(0)
}
-func (addrSetFunctions) ClearValue(val *reflect.Value) {
+func (addrSetFunctions) ClearValue(val **objectEncodeState) {
+ *val = nil
}
-func (addrSetFunctions) Merge(_ addrRange, val1 reflect.Value, _ addrRange, val2 reflect.Value) (reflect.Value, bool) {
- return val1, val1 == val2
+func (addrSetFunctions) Merge(r1 addrRange, val1 *objectEncodeState, r2 addrRange, val2 *objectEncodeState) (*objectEncodeState, bool) {
+ if val1.obj == val2.obj {
+ // This, should never happen. It would indicate that the same
+ // object exists in two non-contiguous address ranges. Note
+ // that this assertion can only be triggered if the race
+ // detector is enabled.
+ Failf("unexpected merge in addrSet @ %v and %v: %#v and %#v", r1, r2, val1.obj, val2.obj)
+ }
+ // Reject the merge.
+ return val1, false
}
-func (addrSetFunctions) Split(_ addrRange, val reflect.Value, _ uintptr) (reflect.Value, reflect.Value) {
- return val, val
+func (addrSetFunctions) Split(r addrRange, val *objectEncodeState, _ uintptr) (*objectEncodeState, *objectEncodeState) {
+ // A split should never happen: we don't remove ranges.
+ Failf("unexpected split in addrSet @ %v: %#v", r, val.obj)
+ panic("unreachable")
}
diff --git a/pkg/state/encode_unsafe.go b/pkg/state/encode_unsafe.go
index 457e6dbb7..e0dad83b4 100644
--- a/pkg/state/encode_unsafe.go
+++ b/pkg/state/encode_unsafe.go
@@ -31,51 +31,3 @@ func arrayFromSlice(obj reflect.Value) reflect.Value {
reflect.ArrayOf(obj.Cap(), obj.Type().Elem()),
unsafe.Pointer(obj.Pointer()))
}
-
-// pbSlice returns a protobuf-supported slice of the array and erase the
-// original element type (which could be a defined type or non-supported type).
-func pbSlice(obj reflect.Value) reflect.Value {
- var typ reflect.Type
- switch obj.Type().Elem().Kind() {
- case reflect.Uint8:
- typ = reflect.TypeOf(byte(0))
- case reflect.Uint16:
- typ = reflect.TypeOf(uint16(0))
- case reflect.Uint32:
- typ = reflect.TypeOf(uint32(0))
- case reflect.Uint64:
- typ = reflect.TypeOf(uint64(0))
- case reflect.Uintptr:
- typ = reflect.TypeOf(uint64(0))
- case reflect.Int8:
- typ = reflect.TypeOf(byte(0))
- case reflect.Int16:
- typ = reflect.TypeOf(int16(0))
- case reflect.Int32:
- typ = reflect.TypeOf(int32(0))
- case reflect.Int64:
- typ = reflect.TypeOf(int64(0))
- case reflect.Bool:
- typ = reflect.TypeOf(bool(false))
- case reflect.Float32:
- typ = reflect.TypeOf(float32(0))
- case reflect.Float64:
- typ = reflect.TypeOf(float64(0))
- default:
- panic("slice element is not of basic value type")
- }
- return reflect.NewAt(
- reflect.ArrayOf(obj.Len(), typ),
- unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()),
- ).Elem().Slice(0, obj.Len())
-}
-
-func castSlice(obj reflect.Value, elemTyp reflect.Type) reflect.Value {
- if obj.Type().Elem().Size() != elemTyp.Size() {
- panic("cannot cast slice into other element type of different size")
- }
- return reflect.NewAt(
- reflect.ArrayOf(obj.Len(), elemTyp),
- unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()),
- ).Elem()
-}
diff --git a/pkg/state/map.go b/pkg/state/map.go
deleted file mode 100644
index 4f3ebb0da..000000000
--- a/pkg/state/map.go
+++ /dev/null
@@ -1,232 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package state
-
-import (
- "context"
- "fmt"
- "reflect"
- "sort"
- "sync"
-
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
-)
-
-// entry is a single map entry.
-type entry struct {
- name string
- object *pb.Object
-}
-
-// internalMap is the internal Map state.
-//
-// These are recycled via a pool to avoid churn.
-type internalMap struct {
- // es is encodeState.
- es *encodeState
-
- // ds is decodeState.
- ds *decodeState
-
- // os is current object being decoded.
- //
- // This will always be nil during encode.
- os *objectState
-
- // data stores the encoded values.
- data []entry
-}
-
-var internalMapPool = sync.Pool{
- New: func() interface{} {
- return new(internalMap)
- },
-}
-
-// newInternalMap returns a cached map.
-func newInternalMap(es *encodeState, ds *decodeState, os *objectState) *internalMap {
- m := internalMapPool.Get().(*internalMap)
- m.es = es
- m.ds = ds
- m.os = os
- if m.data != nil {
- m.data = m.data[:0]
- }
- return m
-}
-
-// Map is a generic state container.
-//
-// This is the object passed to Save and Load in order to store their state.
-//
-// Detailed documentation is available in individual methods.
-type Map struct {
- *internalMap
-}
-
-// Save adds the given object to the map.
-//
-// You should pass always pointers to the object you are saving. For example:
-//
-// type X struct {
-// A int
-// B *int
-// }
-//
-// func (x *X) Save(m Map) {
-// m.Save("A", &x.A)
-// m.Save("B", &x.B)
-// }
-//
-// func (x *X) Load(m Map) {
-// m.Load("A", &x.A)
-// m.Load("B", &x.B)
-// }
-func (m Map) Save(name string, objPtr interface{}) {
- m.save(name, reflect.ValueOf(objPtr).Elem(), ".%s")
-}
-
-// SaveValue adds the given object value to the map.
-//
-// This should be used for values where pointers are not available, or casts
-// are required during Save/Load.
-//
-// For example, if we want to cast external package type P.Foo to int64:
-//
-// type X struct {
-// A P.Foo
-// }
-//
-// func (x *X) Save(m Map) {
-// m.SaveValue("A", int64(x.A))
-// }
-//
-// func (x *X) Load(m Map) {
-// m.LoadValue("A", new(int64), func(x interface{}) {
-// x.A = P.Foo(x.(int64))
-// })
-// }
-func (m Map) SaveValue(name string, obj interface{}) {
- m.save(name, reflect.ValueOf(obj), ".(value %s)")
-}
-
-// save is helper for the above. It takes the name of value to save the field
-// to, the field object (obj), and a format string that specifies how the
-// field's saving logic is dispatched from the struct (normal, value, etc.). The
-// format string should expect one string parameter, which is the name of the
-// field.
-func (m Map) save(name string, obj reflect.Value, format string) {
- if m.es == nil {
- // Not currently encoding.
- m.Failf("no encode state for %q", name)
- }
-
- // Attempt the encode.
- //
- // These are sorted at the end, after all objects are added and will be
- // sorted and checked for duplicates (see encodeStruct).
- m.data = append(m.data, entry{
- name: name,
- object: m.es.encodeObject(obj, false, format, name),
- })
-}
-
-// Load loads the given object from the map.
-//
-// See Save for an example.
-func (m Map) Load(name string, objPtr interface{}) {
- m.load(name, reflect.ValueOf(objPtr), false, nil, ".%s")
-}
-
-// LoadWait loads the given objects from the map, and marks it as requiring all
-// AfterLoad executions to complete prior to running this object's AfterLoad.
-//
-// See Save for an example.
-func (m Map) LoadWait(name string, objPtr interface{}) {
- m.load(name, reflect.ValueOf(objPtr), true, nil, ".(wait %s)")
-}
-
-// LoadValue loads the given object value from the map.
-//
-// See SaveValue for an example.
-func (m Map) LoadValue(name string, objPtr interface{}, fn func(interface{})) {
- o := reflect.ValueOf(objPtr)
- m.load(name, o, true, func() { fn(o.Elem().Interface()) }, ".(value %s)")
-}
-
-// load is helper for the above. It takes the name of value to load the field
-// from, the target field pointer (objPtr), whether load completion of the
-// struct depends on the field's load completion (wait), the load completion
-// logic (fn), and a format string that specifies how the field's loading logic
-// is dispatched from the struct (normal, wait, value, etc.). The format string
-// should expect one string parameter, which is the name of the field.
-func (m Map) load(name string, objPtr reflect.Value, wait bool, fn func(), format string) {
- if m.ds == nil {
- // Not currently decoding.
- m.Failf("no decode state for %q", name)
- }
-
- // Find the object.
- //
- // These are sorted up front (and should appear in the state file
- // sorted as well), so we can do a binary search here to ensure that
- // large structs don't behave badly.
- i := sort.Search(len(m.data), func(i int) bool {
- return m.data[i].name >= name
- })
- if i >= len(m.data) || m.data[i].name != name {
- // There is no data for this name?
- m.Failf("no data found for %q", name)
- }
-
- // Perform the decode.
- m.ds.decodeObject(m.os, objPtr.Elem(), m.data[i].object, format, name)
- if wait {
- // Mark this individual object a blocker.
- m.ds.waitObject(m.os, m.data[i].object, fn)
- }
-}
-
-// Failf fails the save or restore with the provided message. Processing will
-// stop after calling Failf, as the state package uses a panic & recover
-// mechanism for state errors. You should defer any cleanup required.
-func (m Map) Failf(format string, args ...interface{}) {
- panic(fmt.Errorf(format, args...))
-}
-
-// AfterLoad schedules a function execution when all objects have been allocated
-// and their automated loading and customized load logic have been executed. fn
-// will not be executed until all of current object's dependencies' AfterLoad()
-// logic, if exist, have been executed.
-func (m Map) AfterLoad(fn func()) {
- if m.ds == nil {
- // Not currently decoding.
- m.Failf("not decoding")
- }
-
- // Queue the local callback; this will execute when all of the above
- // data dependencies have been cleared.
- m.os.callbacks = append(m.os.callbacks, fn)
-}
-
-// Context returns the current context object.
-func (m Map) Context() context.Context {
- if m.es != nil {
- return m.es.ctx
- } else if m.ds != nil {
- return m.ds.ctx
- }
- return context.Background() // No context.
-}
diff --git a/pkg/state/object.proto b/pkg/state/object.proto
deleted file mode 100644
index 5ebcfb151..000000000
--- a/pkg/state/object.proto
+++ /dev/null
@@ -1,140 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package gvisor.state.statefile;
-
-// Slice is a slice value.
-message Slice {
- uint32 length = 1;
- uint32 capacity = 2;
- uint64 ref_value = 3;
-}
-
-// Array is an array value.
-message Array {
- repeated Object contents = 1;
-}
-
-// Map is a map value.
-message Map {
- repeated Object keys = 1;
- repeated Object values = 2;
-}
-
-// Interface is an interface value.
-message Interface {
- string type = 1;
- Object value = 2;
-}
-
-// Struct is a basic composite value.
-message Struct {
- repeated Field fields = 1;
-}
-
-// Field encodes a single field.
-message Field {
- string name = 1;
- Object value = 2;
-}
-
-// Uint16s encodes an uint16 array. To be used inside oneof structure.
-message Uint16s {
- // There is no 16-bit type in protobuf so we use variable length 32-bit here.
- repeated uint32 values = 1;
-}
-
-// Uint32s encodes an uint32 array. To be used inside oneof structure.
-message Uint32s {
- repeated fixed32 values = 1;
-}
-
-// Uint64s encodes an uint64 array. To be used inside oneof structure.
-message Uint64s {
- repeated fixed64 values = 1;
-}
-
-// Uintptrs encodes an uintptr array. To be used inside oneof structure.
-message Uintptrs {
- repeated fixed64 values = 1;
-}
-
-// Int8s encodes an int8 array. To be used inside oneof structure.
-message Int8s {
- bytes values = 1;
-}
-
-// Int16s encodes an int16 array. To be used inside oneof structure.
-message Int16s {
- // There is no 16-bit type in protobuf so we use variable length 32-bit here.
- repeated int32 values = 1;
-}
-
-// Int32s encodes an int32 array. To be used inside oneof structure.
-message Int32s {
- repeated sfixed32 values = 1;
-}
-
-// Int64s encodes an int64 array. To be used inside oneof structure.
-message Int64s {
- repeated sfixed64 values = 1;
-}
-
-// Bools encodes a boolean array. To be used inside oneof structure.
-message Bools {
- repeated bool values = 1;
-}
-
-// Float64s encodes a float64 array. To be used inside oneof structure.
-message Float64s {
- repeated double values = 1;
-}
-
-// Float32s encodes a float32 array. To be used inside oneof structure.
-message Float32s {
- repeated float values = 1;
-}
-
-// Object are primitive encodings.
-//
-// Note that ref_value references an Object.id, below.
-message Object {
- oneof value {
- bool bool_value = 1;
- bytes string_value = 2;
- int64 int64_value = 3;
- uint64 uint64_value = 4;
- double double_value = 5;
- uint64 ref_value = 6;
- Slice slice_value = 7;
- Array array_value = 8;
- Interface interface_value = 9;
- Struct struct_value = 10;
- Map map_value = 11;
- bytes byte_array_value = 12;
- Uint16s uint16_array_value = 13;
- Uint32s uint32_array_value = 14;
- Uint64s uint64_array_value = 15;
- Uintptrs uintptr_array_value = 16;
- Int8s int8_array_value = 17;
- Int16s int16_array_value = 18;
- Int32s int32_array_value = 19;
- Int64s int64_array_value = 20;
- Bools bool_array_value = 21;
- Float64s float64_array_value = 22;
- Float32s float32_array_value = 23;
- }
-}
diff --git a/pkg/state/pretty/BUILD b/pkg/state/pretty/BUILD
new file mode 100644
index 000000000..d053802f7
--- /dev/null
+++ b/pkg/state/pretty/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pretty",
+ srcs = ["pretty.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/state",
+ "//pkg/state/wire",
+ ],
+)
diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go
new file mode 100644
index 000000000..cf37aaa49
--- /dev/null
+++ b/pkg/state/pretty/pretty.go
@@ -0,0 +1,273 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pretty is a pretty-printer for state streams.
+package pretty
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "reflect"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+func formatRef(x *wire.Ref, graph uint64, html bool) string {
+ baseRef := fmt.Sprintf("g%dr%d", graph, x.Root)
+ fullRef := baseRef
+ if len(x.Dots) > 0 {
+ // See wire.Ref; Type valid if Dots non-zero.
+ typ, _ := formatType(x.Type, graph, html)
+ var buf strings.Builder
+ buf.WriteString("(*")
+ buf.WriteString(typ)
+ buf.WriteString(")(")
+ buf.WriteString(baseRef)
+ for _, component := range x.Dots {
+ switch v := component.(type) {
+ case *wire.FieldName:
+ buf.WriteString(".")
+ buf.WriteString(string(*v))
+ case wire.Index:
+ buf.WriteString(fmt.Sprintf("[%d]", v))
+ default:
+ panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component)))
+ }
+ }
+ buf.WriteString(")")
+ fullRef = buf.String()
+ }
+ if html {
+ return fmt.Sprintf("<a href=\"#%s\">%s</a>", baseRef, fullRef)
+ }
+ return fullRef
+}
+
+func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) {
+ switch x := t.(type) {
+ case wire.TypeID:
+ base := fmt.Sprintf("g%dt%d", graph, x)
+ if html {
+ return fmt.Sprintf("<a href=\"#%s\">%s</a>", base, base), true
+ }
+ return fmt.Sprintf("%s", base), true
+ case wire.TypeSpecNil:
+ return "", false // Only nil type.
+ case *wire.TypeSpecPointer:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("(*%s)", element), true
+ case *wire.TypeSpecArray:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("[%d](%s)", x.Count, element), true
+ case *wire.TypeSpecSlice:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("([]%s)", element), true
+ case *wire.TypeSpecMap:
+ key, _ := formatType(x.Key, graph, html)
+ value, _ := formatType(x.Value, graph, html)
+ return fmt.Sprintf("(map[%s]%s)", key, value), true
+ default:
+ panic(fmt.Sprintf("unreachable: unknown type %T", t))
+ }
+}
+
+// format formats a single object, for pretty-printing. It also returns whether
+// the value is a non-zero value.
+func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bool) {
+ switch x := encoded.(type) {
+ case wire.Nil:
+ return "nil", false
+ case *wire.String:
+ return fmt.Sprintf("%q", *x), *x != ""
+ case *wire.Complex64:
+ return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0
+ case *wire.Complex128:
+ return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0
+ case *wire.Ref:
+ return formatRef(x, graph, html), x.Root != 0
+ case *wire.Type:
+ tabs := "\n" + strings.Repeat("\t", depth)
+ items := make([]string, 0, len(x.Fields)+2)
+ items = append(items, fmt.Sprintf("type %s {", x.Name))
+ for i := 0; i < len(x.Fields); i++ {
+ items = append(items, fmt.Sprintf("\t%d: %s,", i, x.Fields[i]))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), true // No zero value.
+ case *wire.Slice:
+ return fmt.Sprintf("%s{len:%d,cap:%d}", formatRef(&x.Ref, graph, html), x.Length, x.Capacity), x.Capacity != 0
+ case *wire.Array:
+ if len(x.Contents) == 0 {
+ return "[]", false
+ }
+ items := make([]string, 0, len(x.Contents)+2)
+ zeros := make([]string, 0) // used to eliminate zero entries.
+ items = append(items, "[")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ for i := 0; i < len(x.Contents); i++ {
+ item, ok := format(graph, depth+1, x.Contents[i], html)
+ if !ok {
+ zeros = append(zeros, fmt.Sprintf("\t%s,", item))
+ continue
+ }
+ if len(zeros) > 0 {
+ items = append(items, zeros...)
+ zeros = nil
+ }
+ items = append(items, fmt.Sprintf("\t%s,", item))
+ }
+ if len(zeros) > 0 {
+ items = append(items, fmt.Sprintf("\t... (%d zeros),", len(zeros)))
+ }
+ items = append(items, "]")
+ return strings.Join(items, tabs), len(zeros) < len(x.Contents)
+ case *wire.Struct:
+ typ, _ := formatType(x.TypeID, graph, html)
+ if x.Fields() == 0 {
+ return fmt.Sprintf("struct[%s]{}", typ), false
+ }
+ items := make([]string, 0, 2)
+ items = append(items, fmt.Sprintf("struct[%s]{", typ))
+ tabs := "\n" + strings.Repeat("\t", depth)
+ allZero := true
+ for i := 0; i < x.Fields(); i++ {
+ element, ok := format(graph, depth+1, *x.Field(i), html)
+ allZero = allZero && !ok
+ items = append(items, fmt.Sprintf("\t%d: %s,", i, element))
+ i++
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), !allZero
+ case *wire.Map:
+ if len(x.Keys) == 0 {
+ return "map{}", false
+ }
+ items := make([]string, 0, len(x.Keys)+2)
+ items = append(items, "map{")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ for i := 0; i < len(x.Keys); i++ {
+ key, _ := format(graph, depth+1, x.Keys[i], html)
+ value, _ := format(graph, depth+1, x.Values[i], html)
+ items = append(items, fmt.Sprintf("\t%s: %s,", key, value))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), true
+ case *wire.Interface:
+ typ, typOk := formatType(x.Type, graph, html)
+ element, elementOk := format(graph, depth+1, x.Value, html)
+ return fmt.Sprintf("interface[%s]{%s}", typ, element), typOk || elementOk
+ default:
+ // Must be a primitive; use reflection.
+ return fmt.Sprintf("%v", encoded), true
+ }
+}
+
+// printStream is the basic print implementation.
+func printStream(w io.Writer, r wire.Reader, html bool) (err error) {
+ // current graph ID.
+ var graph uint64
+
+ if html {
+ fmt.Fprintf(w, "<pre>")
+ defer fmt.Fprintf(w, "</pre>")
+ }
+
+ defer func() {
+ if r := recover(); r != nil {
+ if rErr, ok := r.(error); ok {
+ err = rErr // Override return.
+ return
+ }
+ panic(r) // Propagate.
+ }
+ }()
+
+ for {
+ // Find the first object to begin generation.
+ length, object, err := state.ReadHeader(r)
+ if err == io.EOF {
+ // Nothing else to do.
+ break
+ } else if err != nil {
+ return err
+ }
+ if !object {
+ graph++ // Increment the graph.
+ if length > 0 {
+ fmt.Fprintf(w, "(%d bytes non-object data)\n", length)
+ io.Copy(ioutil.Discard, &io.LimitedReader{
+ R: r,
+ N: int64(length),
+ })
+ }
+ continue
+ }
+
+ // Read & unmarshal the object.
+ //
+ // Note that this loop must match the general structure of the
+ // loop in decode.go. But we don't register type information,
+ // etc. and just print the raw structures.
+ var (
+ oid uint64 = 1
+ tid uint64 = 1
+ )
+ for oid <= length {
+ // Unmarshal the object.
+ encoded := wire.Load(r)
+
+ // Is this a type?
+ if _, ok := encoded.(*wire.Type); ok {
+ str, _ := format(graph, 0, encoded, html)
+ tag := fmt.Sprintf("g%dt%d", graph, tid)
+ if html {
+ // See below.
+ tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
+ }
+ if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+ return err
+ }
+ tid++
+ continue
+ }
+
+ // Format the node.
+ str, _ := format(graph, 0, encoded, html)
+ tag := fmt.Sprintf("g%dr%d", graph, oid)
+ if html {
+ // Create a little tag with an anchor next to it for linking.
+ tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
+ }
+ if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+ return err
+ }
+ oid++
+ }
+ }
+
+ return nil
+}
+
+// PrintText reads the stream from r and prints text to w.
+func PrintText(w io.Writer, r wire.Reader) error {
+ return printStream(w, r, false /* html */)
+}
+
+// PrintHTML reads the stream from r and prints html to w.
+func PrintHTML(w io.Writer, r wire.Reader) error {
+ return printStream(w, r, true /* html */)
+}
diff --git a/pkg/state/printer.go b/pkg/state/printer.go
deleted file mode 100644
index 3ce18242f..000000000
--- a/pkg/state/printer.go
+++ /dev/null
@@ -1,251 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package state
-
-import (
- "fmt"
- "io"
- "io/ioutil"
- "reflect"
- "strings"
-
- "github.com/golang/protobuf/proto"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
-)
-
-// format formats a single object, for pretty-printing. It also returns whether
-// the value is a non-zero value.
-func format(graph uint64, depth int, object *pb.Object, html bool) (string, bool) {
- switch x := object.GetValue().(type) {
- case *pb.Object_BoolValue:
- return fmt.Sprintf("%t", x.BoolValue), x.BoolValue != false
- case *pb.Object_StringValue:
- return fmt.Sprintf("\"%s\"", string(x.StringValue)), len(x.StringValue) != 0
- case *pb.Object_Int64Value:
- return fmt.Sprintf("%d", x.Int64Value), x.Int64Value != 0
- case *pb.Object_Uint64Value:
- return fmt.Sprintf("%du", x.Uint64Value), x.Uint64Value != 0
- case *pb.Object_DoubleValue:
- return fmt.Sprintf("%f", x.DoubleValue), x.DoubleValue != 0.0
- case *pb.Object_RefValue:
- if x.RefValue == 0 {
- return "nil", false
- }
- ref := fmt.Sprintf("g%dr%d", graph, x.RefValue)
- if html {
- ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref)
- }
- return ref, true
- case *pb.Object_SliceValue:
- if x.SliceValue.RefValue == 0 {
- return "nil", false
- }
- ref := fmt.Sprintf("g%dr%d", graph, x.SliceValue.RefValue)
- if html {
- ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref)
- }
- return fmt.Sprintf("%s[:%d:%d]", ref, x.SliceValue.Length, x.SliceValue.Capacity), true
- case *pb.Object_ArrayValue:
- if len(x.ArrayValue.Contents) == 0 {
- return "[]", false
- }
- items := make([]string, 0, len(x.ArrayValue.Contents)+2)
- zeros := make([]string, 0) // used to eliminate zero entries.
- items = append(items, "[")
- tabs := "\n" + strings.Repeat("\t", depth)
- for i := 0; i < len(x.ArrayValue.Contents); i++ {
- item, ok := format(graph, depth+1, x.ArrayValue.Contents[i], html)
- if ok {
- if len(zeros) > 0 {
- items = append(items, zeros...)
- zeros = nil
- }
- items = append(items, fmt.Sprintf("\t%s,", item))
- } else {
- zeros = append(zeros, fmt.Sprintf("\t%s,", item))
- }
- }
- if len(zeros) > 0 {
- items = append(items, fmt.Sprintf("\t... (%d zeros),", len(zeros)))
- }
- items = append(items, "]")
- return strings.Join(items, tabs), len(zeros) < len(x.ArrayValue.Contents)
- case *pb.Object_StructValue:
- if len(x.StructValue.Fields) == 0 {
- return "struct{}", false
- }
- items := make([]string, 0, len(x.StructValue.Fields)+2)
- items = append(items, "struct{")
- tabs := "\n" + strings.Repeat("\t", depth)
- allZero := true
- for _, field := range x.StructValue.Fields {
- element, ok := format(graph, depth+1, field.Value, html)
- allZero = allZero && !ok
- items = append(items, fmt.Sprintf("\t%s: %s,", field.Name, element))
- }
- items = append(items, "}")
- return strings.Join(items, tabs), !allZero
- case *pb.Object_MapValue:
- if len(x.MapValue.Keys) == 0 {
- return "map{}", false
- }
- items := make([]string, 0, len(x.MapValue.Keys)+2)
- items = append(items, "map{")
- tabs := "\n" + strings.Repeat("\t", depth)
- for i := 0; i < len(x.MapValue.Keys); i++ {
- key, _ := format(graph, depth+1, x.MapValue.Keys[i], html)
- value, _ := format(graph, depth+1, x.MapValue.Values[i], html)
- items = append(items, fmt.Sprintf("\t%s: %s,", key, value))
- }
- items = append(items, "}")
- return strings.Join(items, tabs), true
- case *pb.Object_InterfaceValue:
- if x.InterfaceValue.Type == "" {
- return "interface(nil){}", false
- }
- element, _ := format(graph, depth+1, x.InterfaceValue.Value, html)
- return fmt.Sprintf("interface(\"%s\"){%s}", x.InterfaceValue.Type, element), true
- case *pb.Object_ByteArrayValue:
- return printArray(reflect.ValueOf(x.ByteArrayValue))
- case *pb.Object_Uint16ArrayValue:
- return printArray(reflect.ValueOf(x.Uint16ArrayValue.Values))
- case *pb.Object_Uint32ArrayValue:
- return printArray(reflect.ValueOf(x.Uint32ArrayValue.Values))
- case *pb.Object_Uint64ArrayValue:
- return printArray(reflect.ValueOf(x.Uint64ArrayValue.Values))
- case *pb.Object_UintptrArrayValue:
- return printArray(castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
- case *pb.Object_Int8ArrayValue:
- return printArray(castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
- case *pb.Object_Int16ArrayValue:
- return printArray(reflect.ValueOf(x.Int16ArrayValue.Values))
- case *pb.Object_Int32ArrayValue:
- return printArray(reflect.ValueOf(x.Int32ArrayValue.Values))
- case *pb.Object_Int64ArrayValue:
- return printArray(reflect.ValueOf(x.Int64ArrayValue.Values))
- case *pb.Object_BoolArrayValue:
- return printArray(reflect.ValueOf(x.BoolArrayValue.Values))
- case *pb.Object_Float64ArrayValue:
- return printArray(reflect.ValueOf(x.Float64ArrayValue.Values))
- case *pb.Object_Float32ArrayValue:
- return printArray(reflect.ValueOf(x.Float32ArrayValue.Values))
- }
-
- // Should not happen, but tolerate.
- return fmt.Sprintf("(unknown proto type: %T)", object.GetValue()), true
-}
-
-// PrettyPrint reads the state stream from r, and pretty prints to w.
-func PrettyPrint(w io.Writer, r io.Reader, html bool) error {
- var (
- // current graph ID.
- graph uint64
-
- // current object ID.
- id uint64
- )
-
- if html {
- fmt.Fprintf(w, "<pre>")
- defer fmt.Fprintf(w, "</pre>")
- }
-
- for {
- // Find the first object to begin generation.
- length, object, err := ReadHeader(r)
- if err == io.EOF {
- // Nothing else to do.
- break
- } else if err != nil {
- return err
- }
- if !object {
- // Increment the graph number & reset the ID.
- graph++
- id = 0
- if length > 0 {
- fmt.Fprintf(w, "(%d bytes non-object data)\n", length)
- io.Copy(ioutil.Discard, &io.LimitedReader{
- R: r,
- N: int64(length),
- })
- }
- continue
- }
-
- // Read & unmarshal the object.
- buf := make([]byte, length)
- for done := 0; done < len(buf); {
- n, err := r.Read(buf[done:])
- done += n
- if n == 0 && err != nil {
- return err
- }
- }
- obj := new(pb.Object)
- if err := proto.Unmarshal(buf, obj); err != nil {
- return err
- }
-
- id++ // First object must be one.
- str, _ := format(graph, 0, obj, html)
- tag := fmt.Sprintf("g%dr%d", graph, id)
- if html {
- tag = fmt.Sprintf("<a name=%s>%s</a>", tag, tag)
- }
- if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func printArray(s reflect.Value) (string, bool) {
- zero := reflect.Zero(s.Type().Elem()).Interface()
- z := "0"
- switch s.Type().Elem().Kind() {
- case reflect.Bool:
- z = "false"
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- case reflect.Float32, reflect.Float64:
- default:
- return fmt.Sprintf("unexpected non-primitive type array: %#v", s.Interface()), true
- }
-
- zeros := 0
- items := make([]string, 0, s.Len())
- for i := 0; i <= s.Len(); i++ {
- if i < s.Len() && reflect.DeepEqual(s.Index(i).Interface(), zero) {
- zeros++
- continue
- }
- if zeros > 0 {
- if zeros <= 4 {
- for ; zeros > 0; zeros-- {
- items = append(items, z)
- }
- } else {
- items = append(items, fmt.Sprintf("(%d %ss)", zeros, z))
- zeros = 0
- }
- }
- if i < s.Len() {
- items = append(items, fmt.Sprintf("%v", s.Index(i).Interface()))
- }
- }
- return "[" + strings.Join(items, ",") + "]", zeros < s.Len()
-}
diff --git a/pkg/state/state.go b/pkg/state/state.go
index 03ae2dbb0..acb629969 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -31,210 +31,226 @@
// Uint64 default
// Float32 default
// Float64 default
-// Complex64 custom
-// Complex128 custom
+// Complex64 default
+// Complex128 default
// Array default
// Chan custom
// Func custom
-// Interface custom
-// Map default (*)
+// Interface default
+// Map default
// Ptr default
// Slice default
// String default
-// Struct custom
+// Struct custom (*) Unless zero-sized.
// UnsafePointer custom
//
-// (*) Maps are treated as value types by this package, even if they are
-// pointers internally. If you want to save two independent references
-// to the same map value, you must explicitly use a pointer to a map.
+// See README.md for an overview of how encoding and decoding works.
package state
import (
"context"
"fmt"
- "io"
"reflect"
"runtime"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
+// objectID is a unique identifier assigned to each object to be serialized.
+// Each instance of an object is considered separately, i.e. if there are two
+// objects of the same type in the object graph being serialized, they'll be
+// assigned unique objectIDs.
+type objectID uint32
+
+// typeID is the identifier for a type. Types are serialized and tracked
+// alongside objects in order to avoid the overhead of encoding field names in
+// all objects.
+type typeID uint32
+
// ErrState is returned when an error is encountered during encode/decode.
type ErrState struct {
// err is the underlying error.
err error
- // path is the visit path from root to the current object.
- path string
-
// trace is the stack trace.
trace string
}
// Error returns a sensible description of the state error.
func (e *ErrState) Error() string {
- return fmt.Sprintf("%v:\nstate path: %s\n%s", e.err, e.path, e.trace)
+ return fmt.Sprintf("%v:\n%s", e.err, e.trace)
}
-// UnwrapErrState returns the underlying error in ErrState.
-//
-// If err is not *ErrState, err is returned directly.
-func UnwrapErrState(err error) error {
- if e, ok := err.(*ErrState); ok {
- return e.err
- }
- return err
+// Unwrap implements standard unwrapping.
+func (e *ErrState) Unwrap() error {
+ return e.err
}
// Save saves the given object state.
-func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error {
+func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) {
// Create the encoding state.
- es := &encodeState{
- ctx: ctx,
- idsByObject: make(map[uintptr]uint64),
- w: w,
- stats: stats,
+ es := encodeState{
+ ctx: ctx,
+ w: w,
+ types: makeTypeEncodeDatabase(),
+ zeroValues: make(map[reflect.Type]*objectEncodeState),
}
// Perform the encoding.
- return es.safely(func() {
- es.Serialize(reflect.ValueOf(rootPtr).Elem())
+ err := safely(func() {
+ es.Save(reflect.ValueOf(rootPtr).Elem())
})
+ return es.stats, err
}
// Load loads a checkpoint.
-func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error {
+func Load(ctx context.Context, r wire.Reader, rootPtr interface{}) (Stats, error) {
// Create the decoding state.
- ds := &decodeState{
- ctx: ctx,
- objectsByID: make(map[uint64]*objectState),
- deferred: make(map[uint64]*pb.Object),
- r: r,
- stats: stats,
+ ds := decodeState{
+ ctx: ctx,
+ r: r,
+ types: makeTypeDecodeDatabase(),
+ deferred: make(map[objectID]wire.Object),
}
// Attempt our decode.
- return ds.safely(func() {
- ds.Deserialize(reflect.ValueOf(rootPtr).Elem())
+ err := safely(func() {
+ ds.Load(reflect.ValueOf(rootPtr).Elem())
})
+ return ds.stats, err
}
-// Fns are the state dispatch functions.
-type Fns struct {
- // Save is a function like Save(concreteType, Map).
- Save interface{}
-
- // Load is a function like Load(concreteType, Map).
- Load interface{}
+// Sink is used for Type.StateSave.
+type Sink struct {
+ internal objectEncoder
}
-// Save executes the save function.
-func (fns *Fns) invokeSave(obj reflect.Value, m Map) {
- reflect.ValueOf(fns.Save).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+// Save adds the given object to the map.
+//
+// You should pass always pointers to the object you are saving. For example:
+//
+// type X struct {
+// A int
+// B *int
+// }
+//
+// func (x *X) StateTypeInfo(m Sink) state.TypeInfo {
+// return state.TypeInfo{
+// Name: "pkg.X",
+// Fields: []string{
+// "A",
+// "B",
+// },
+// }
+// }
+//
+// func (x *X) StateSave(m Sink) {
+// m.Save(0, &x.A) // Field is A.
+// m.Save(1, &x.B) // Field is B.
+// }
+//
+// func (x *X) StateLoad(m Source) {
+// m.Load(0, &x.A) // Field is A.
+// m.Load(1, &x.B) // Field is B.
+// }
+func (s Sink) Save(slot int, objPtr interface{}) {
+ s.internal.save(slot, reflect.ValueOf(objPtr).Elem())
}
-// Load executes the load function.
-func (fns *Fns) invokeLoad(obj reflect.Value, m Map) {
- reflect.ValueOf(fns.Load).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+// SaveValue adds the given object value to the map.
+//
+// This should be used for values where pointers are not available, or casts
+// are required during Save/Load.
+//
+// For example, if we want to cast external package type P.Foo to int64:
+//
+// func (x *X) StateSave(m Sink) {
+// m.SaveValue(0, "A", int64(x.A))
+// }
+//
+// func (x *X) StateLoad(m Source) {
+// m.LoadValue(0, new(int64), func(x interface{}) {
+// x.A = P.Foo(x.(int64))
+// })
+// }
+func (s Sink) SaveValue(slot int, obj interface{}) {
+ s.internal.save(slot, reflect.ValueOf(obj))
}
-// validateStateFn ensures types are correct.
-func validateStateFn(fn interface{}, typ reflect.Type) bool {
- fnTyp := reflect.TypeOf(fn)
- if fnTyp.Kind() != reflect.Func {
- return false
- }
- if fnTyp.NumIn() != 2 {
- return false
- }
- if fnTyp.NumOut() != 0 {
- return false
- }
- if fnTyp.In(0) != typ {
- return false
- }
- if fnTyp.In(1) != reflect.TypeOf(Map{}) {
- return false
- }
- return true
+// Context returns the context object provided at save time.
+func (s Sink) Context() context.Context {
+ return s.internal.es.ctx
}
-// Validate validates all state functions.
-func (fns *Fns) Validate(typ reflect.Type) bool {
- return validateStateFn(fns.Save, typ) && validateStateFn(fns.Load, typ)
+// Type is an interface that must be implemented by Struct objects. This allows
+// these objects to be serialized while minimizing runtime reflection required.
+//
+// All these methods can be automatically generated by the go_statify tool.
+type Type interface {
+ // StateTypeName returns the type's name.
+ //
+ // This is used for matching type information during encoding and
+ // decoding, as well as dynamic interface dispatch. This should be
+ // globally unique.
+ StateTypeName() string
+
+ // StateFields returns information about the type.
+ //
+ // Fields is the set of fields for the object. Calls to Sink.Save and
+ // Source.Load must be made in-order with respect to these fields.
+ //
+ // This will be called at most once per serialization.
+ StateFields() []string
}
-type typeDatabase struct {
- // nameToType is a forward lookup table.
- nameToType map[string]reflect.Type
-
- // typeToName is the reverse lookup table.
- typeToName map[reflect.Type]string
+// SaverLoader must be implemented by struct types.
+type SaverLoader interface {
+ // StateSave saves the state of the object to the given Map.
+ StateSave(Sink)
- // typeToFns is the function lookup table.
- typeToFns map[reflect.Type]Fns
+ // StateLoad loads the state of the object.
+ StateLoad(Source)
}
-// registeredTypes is a database used for SaveInterface and LoadInterface.
-var registeredTypes = typeDatabase{
- nameToType: make(map[string]reflect.Type),
- typeToName: make(map[reflect.Type]string),
- typeToFns: make(map[reflect.Type]Fns),
+// Source is used for Type.StateLoad.
+type Source struct {
+ internal objectDecoder
}
-// register registers a type under the given name. This will generally be
-// called via init() methods, and therefore uses panic to propagate errors.
-func (t *typeDatabase) register(name string, typ reflect.Type, fns Fns) {
- // We can't allow name collisions.
- if ot, ok := t.nameToType[name]; ok {
- panic(fmt.Sprintf("type %q can't use name %q, already in use by type %q", typ.Name(), name, ot.Name()))
- }
-
- // Or multiple registrations.
- if on, ok := t.typeToName[typ]; ok {
- panic(fmt.Sprintf("type %q can't be registered as %q, already registered as %q", typ.Name(), name, on))
- }
-
- t.nameToType[name] = typ
- t.typeToName[typ] = name
- t.typeToFns[typ] = fns
+// Load loads the given object passed as a pointer..
+//
+// See Sink.Save for an example.
+func (s Source) Load(slot int, objPtr interface{}) {
+ s.internal.load(slot, reflect.ValueOf(objPtr), false, nil)
}
-// lookupType finds a type given a name.
-func (t *typeDatabase) lookupType(name string) (reflect.Type, bool) {
- typ, ok := t.nameToType[name]
- return typ, ok
+// LoadWait loads the given objects from the map, and marks it as requiring all
+// AfterLoad executions to complete prior to running this object's AfterLoad.
+//
+// See Sink.Save for an example.
+func (s Source) LoadWait(slot int, objPtr interface{}) {
+ s.internal.load(slot, reflect.ValueOf(objPtr), true, nil)
}
-// lookupName finds a name given a type.
-func (t *typeDatabase) lookupName(typ reflect.Type) (string, bool) {
- name, ok := t.typeToName[typ]
- return name, ok
+// LoadValue loads the given object value from the map.
+//
+// See Sink.SaveValue for an example.
+func (s Source) LoadValue(slot int, objPtr interface{}, fn func(interface{})) {
+ o := reflect.ValueOf(objPtr)
+ s.internal.load(slot, o, true, func() { fn(o.Elem().Interface()) })
}
-// lookupFns finds functions given a type.
-func (t *typeDatabase) lookupFns(typ reflect.Type) (Fns, bool) {
- fns, ok := t.typeToFns[typ]
- return fns, ok
+// AfterLoad schedules a function execution when all objects have been
+// allocated and their automated loading and customized load logic have been
+// executed. fn will not be executed until all of current object's
+// dependencies' AfterLoad() logic, if exist, have been executed.
+func (s Source) AfterLoad(fn func()) {
+ s.internal.afterLoad(fn)
}
-// Register must be called for any interface implementation types that
-// implements Loader.
-//
-// Register should be called either immediately after startup or via init()
-// methods. Double registration of either names or types will result in a panic.
-//
-// No synchronization is provided; this should only be called in init.
-//
-// Example usage:
-//
-// state.Register("Foo", (*Foo)(nil), state.Fns{
-// Save: (*Foo).Save,
-// Load: (*Foo).Load,
-// })
-//
-func Register(name string, instance interface{}, fns Fns) {
- registeredTypes.register(name, reflect.TypeOf(instance), fns)
+// Context returns the context object provided at load time.
+func (s Source) Context() context.Context {
+ return s.internal.ds.ctx
}
// IsZeroValue checks if the given value is the zero value.
@@ -244,72 +260,14 @@ func IsZeroValue(val interface{}) bool {
return val == nil || reflect.ValueOf(val).Elem().IsZero()
}
-// step captures one encoding / decoding step. On each step, there is up to one
-// choice made, which is captured by non-nil param. We intentionally do not
-// eagerly create the final path string, as that will only be needed upon panic.
-type step struct {
- // dereference indicate if the current object is obtained by
- // dereferencing a pointer.
- dereference bool
-
- // format is the formatting string that takes param below, if
- // non-nil. For example, in array indexing case, we have "[%d]".
- format string
-
- // param stores the choice made at the current encoding / decoding step.
- // For eaxmple, in array indexing case, param stores the index. When no
- // choice is made, e.g. dereference, param should be nil.
- param interface{}
-}
-
-// recoverable is the state encoding / decoding panic recovery facility. It is
-// also used to store encoding / decoding steps as well as the reference to the
-// original queued object from which the current object is dispatched. The
-// complete encoding / decoding path is synthesised from the steps in all queued
-// objects leading to the current object.
-type recoverable struct {
- from *recoverable
- steps []step
+// Failf is a wrapper around panic that should be used to generate errors that
+// can be caught during saving and loading.
+func Failf(fmtStr string, v ...interface{}) {
+ panic(fmt.Errorf(fmtStr, v...))
}
-// push enters a new context level.
-func (sr *recoverable) push(dereference bool, format string, param interface{}) {
- sr.steps = append(sr.steps, step{dereference, format, param})
-}
-
-// pop exits the current context level.
-func (sr *recoverable) pop() {
- if len(sr.steps) <= 1 {
- return
- }
- sr.steps = sr.steps[:len(sr.steps)-1]
-}
-
-// path returns the complete encoding / decoding path from root. This is only
-// called upon panic.
-func (sr *recoverable) path() string {
- if sr.from == nil {
- return "root"
- }
- p := sr.from.path()
- for _, s := range sr.steps {
- if s.dereference {
- p = fmt.Sprintf("*(%s)", p)
- }
- if s.param == nil {
- p += s.format
- } else {
- p += fmt.Sprintf(s.format, s.param)
- }
- }
- return p
-}
-
-func (sr *recoverable) copy() recoverable {
- return recoverable{from: sr.from, steps: append([]step(nil), sr.steps...)}
-}
-
-// safely executes the given function, catching a panic and unpacking as an error.
+// safely executes the given function, catching a panic and unpacking as an
+// error.
//
// The error flow through the state package uses panic and recover. There are
// two important reasons for this:
@@ -323,9 +281,15 @@ func (sr *recoverable) copy() recoverable {
// method doesn't add a lot of value. If there are specific error conditions
// that you'd like to handle, you should add appropriate functionality to
// objects themselves prior to calling Save() and Load().
-func (sr *recoverable) safely(fn func()) (err error) {
+func safely(fn func()) (err error) {
defer func() {
if r := recover(); r != nil {
+ if es, ok := r.(*ErrState); ok {
+ err = es // Propagate.
+ return
+ }
+
+ // Build a new state error.
es := new(ErrState)
if e, ok := r.(error); ok {
es.err = e
@@ -333,8 +297,6 @@ func (sr *recoverable) safely(fn func()) (err error) {
es.err = fmt.Errorf("%v", r)
}
- es.path = sr.path()
-
// Make a stack. We don't know how big it will be ahead
// of time, but want to make sure we get the whole
// thing. So we just do a stupid brute force approach.
diff --git a/pkg/state/state_norace.go b/pkg/state/state_norace.go
new file mode 100644
index 000000000..4281aed6d
--- /dev/null
+++ b/pkg/state/state_norace.go
@@ -0,0 +1,19 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+
+package state
+
+var raceEnabled = false
diff --git a/pkg/state/state_race.go b/pkg/state/state_race.go
new file mode 100644
index 000000000..8232981ce
--- /dev/null
+++ b/pkg/state/state_race.go
@@ -0,0 +1,19 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build race
+
+package state
+
+var raceEnabled = true
diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go
deleted file mode 100644
index d7221e9e8..000000000
--- a/pkg/state/state_test.go
+++ /dev/null
@@ -1,721 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package state
-
-import (
- "bytes"
- "context"
- "io/ioutil"
- "math"
- "reflect"
- "testing"
-)
-
-// TestCase is used to define a single success/failure testcase of
-// serialization of a set of objects.
-type TestCase struct {
- // Name is the name of the test case.
- Name string
-
- // Objects is the list of values to serialize.
- Objects []interface{}
-
- // Fail is whether the test case is supposed to fail or not.
- Fail bool
-}
-
-// runTest runs all testcases.
-func runTest(t *testing.T, tests []TestCase) {
- for _, test := range tests {
- t.Logf("TEST %s:", test.Name)
- for i, root := range test.Objects {
- t.Logf(" case#%d: %#v", i, root)
-
- // Save the passed object.
- saveBuffer := &bytes.Buffer{}
- saveObjectPtr := reflect.New(reflect.TypeOf(root))
- saveObjectPtr.Elem().Set(reflect.ValueOf(root))
- if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail {
- t.Errorf(" FAIL: Save failed unexpectedly: %v", err)
- continue
- } else if err != nil {
- t.Logf(" PASS: Save failed as expected: %v", err)
- continue
- }
-
- // Load a new copy of the object.
- loadObjectPtr := reflect.New(reflect.TypeOf(root))
- if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail {
- t.Errorf(" FAIL: Load failed unexpectedly: %v", err)
- continue
- } else if err != nil {
- t.Logf(" PASS: Load failed as expected: %v", err)
- continue
- }
-
- // Compare the values.
- loadedValue := loadObjectPtr.Elem().Interface()
- if eq := reflect.DeepEqual(root, loadedValue); !eq && !test.Fail {
- t.Errorf(" FAIL: Objects differs; got %#v", loadedValue)
- continue
- } else if !eq {
- t.Logf(" PASS: Object different as expected.")
- continue
- }
-
- // Everything went okay. Is that good?
- if test.Fail {
- t.Errorf(" FAIL: Unexpected success.")
- } else {
- t.Logf(" PASS: Success.")
- }
- }
- }
-}
-
-// dumbStruct is a struct which does not implement the loader/saver interface.
-// We expect that serialization of this struct will fail.
-type dumbStruct struct {
- A int
- B int
-}
-
-// smartStruct is a struct which does implement the loader/saver interface.
-// We expect that serialization of this struct will succeed.
-type smartStruct struct {
- A int
- B int
-}
-
-func (s *smartStruct) save(m Map) {
- m.Save("A", &s.A)
- m.Save("B", &s.B)
-}
-
-func (s *smartStruct) load(m Map) {
- m.Load("A", &s.A)
- m.Load("B", &s.B)
-}
-
-// valueLoadStruct uses a value load.
-type valueLoadStruct struct {
- v int
-}
-
-func (v *valueLoadStruct) save(m Map) {
- m.SaveValue("v", v.v)
-}
-
-func (v *valueLoadStruct) load(m Map) {
- m.LoadValue("v", new(int), func(value interface{}) {
- v.v = value.(int)
- })
-}
-
-// afterLoadStruct has an AfterLoad function.
-type afterLoadStruct struct {
- v int
-}
-
-func (a *afterLoadStruct) save(m Map) {
-}
-
-func (a *afterLoadStruct) load(m Map) {
- m.AfterLoad(func() {
- a.v++
- })
-}
-
-// genericContainer is a generic dispatcher.
-type genericContainer struct {
- v interface{}
-}
-
-func (g *genericContainer) save(m Map) {
- m.Save("v", &g.v)
-}
-
-func (g *genericContainer) load(m Map) {
- m.Load("v", &g.v)
-}
-
-// sliceContainer is a generic slice.
-type sliceContainer struct {
- v []interface{}
-}
-
-func (s *sliceContainer) save(m Map) {
- m.Save("v", &s.v)
-}
-
-func (s *sliceContainer) load(m Map) {
- m.Load("v", &s.v)
-}
-
-// mapContainer is a generic map.
-type mapContainer struct {
- v map[int]interface{}
-}
-
-func (mc *mapContainer) save(m Map) {
- m.Save("v", &mc.v)
-}
-
-func (mc *mapContainer) load(m Map) {
- // Some of the test cases below assume legacy behavior wherein maps
- // will automatically inherit dependencies.
- m.LoadWait("v", &mc.v)
-}
-
-// dumbMap is a map which does not implement the loader/saver interface.
-// Serialization of this map will default to the standard encode/decode logic.
-type dumbMap map[string]int
-
-// pointerStruct contains various pointers, shared and non-shared, and pointers
-// to pointers. We expect that serialization will respect the structure.
-type pointerStruct struct {
- A *int
- B *int
- C *int
- D *int
-
- AA **int
- BB **int
-}
-
-func (p *pointerStruct) save(m Map) {
- m.Save("A", &p.A)
- m.Save("B", &p.B)
- m.Save("C", &p.C)
- m.Save("D", &p.D)
- m.Save("AA", &p.AA)
- m.Save("BB", &p.BB)
-}
-
-func (p *pointerStruct) load(m Map) {
- m.Load("A", &p.A)
- m.Load("B", &p.B)
- m.Load("C", &p.C)
- m.Load("D", &p.D)
- m.Load("AA", &p.AA)
- m.Load("BB", &p.BB)
-}
-
-// testInterface is a trivial interface example.
-type testInterface interface {
- Foo()
-}
-
-// testImpl is a trivial implementation of testInterface.
-type testImpl struct {
-}
-
-// Foo satisfies testInterface.
-func (t *testImpl) Foo() {
-}
-
-// testImpl is trivially serializable.
-func (t *testImpl) save(m Map) {
-}
-
-// testImpl is trivially serializable.
-func (t *testImpl) load(m Map) {
-}
-
-// testI demonstrates interface dispatching.
-type testI struct {
- I testInterface
-}
-
-func (t *testI) save(m Map) {
- m.Save("I", &t.I)
-}
-
-func (t *testI) load(m Map) {
- m.Load("I", &t.I)
-}
-
-// cycleStruct is used to implement basic cycles.
-type cycleStruct struct {
- c *cycleStruct
-}
-
-func (c *cycleStruct) save(m Map) {
- m.Save("c", &c.c)
-}
-
-func (c *cycleStruct) load(m Map) {
- m.Load("c", &c.c)
-}
-
-// badCycleStruct actually has deadlocking dependencies.
-//
-// This should pass if b.b = {nil|b} and fail otherwise.
-type badCycleStruct struct {
- b *badCycleStruct
-}
-
-func (b *badCycleStruct) save(m Map) {
- m.Save("b", &b.b)
-}
-
-func (b *badCycleStruct) load(m Map) {
- m.LoadWait("b", &b.b)
- m.AfterLoad(func() {
- // This is not executable, since AfterLoad requires that the
- // object and all dependencies are complete. This should cause
- // a deadlock error during load.
- })
-}
-
-// emptyStructPointer points to an empty struct.
-type emptyStructPointer struct {
- nothing *struct{}
-}
-
-func (e *emptyStructPointer) save(m Map) {
- m.Save("nothing", &e.nothing)
-}
-
-func (e *emptyStructPointer) load(m Map) {
- m.Load("nothing", &e.nothing)
-}
-
-// truncateInteger truncates an integer.
-type truncateInteger struct {
- v int64
- v2 int32
-}
-
-func (t *truncateInteger) save(m Map) {
- t.v2 = int32(t.v)
- m.Save("v", &t.v)
-}
-
-func (t *truncateInteger) load(m Map) {
- m.Load("v", &t.v2)
- t.v = int64(t.v2)
-}
-
-// truncateUnsignedInteger truncates an unsigned integer.
-type truncateUnsignedInteger struct {
- v uint64
- v2 uint32
-}
-
-func (t *truncateUnsignedInteger) save(m Map) {
- t.v2 = uint32(t.v)
- m.Save("v", &t.v)
-}
-
-func (t *truncateUnsignedInteger) load(m Map) {
- m.Load("v", &t.v2)
- t.v = uint64(t.v2)
-}
-
-// truncateFloat truncates a floating point number.
-type truncateFloat struct {
- v float64
- v2 float32
-}
-
-func (t *truncateFloat) save(m Map) {
- t.v2 = float32(t.v)
- m.Save("v", &t.v)
-}
-
-func (t *truncateFloat) load(m Map) {
- m.Load("v", &t.v2)
- t.v = float64(t.v2)
-}
-
-func TestTypes(t *testing.T) {
- // x and y are basic integers, while xp points to x.
- x := 1
- y := 2
- xp := &x
-
- // cs is a single object cycle.
- cs := cycleStruct{nil}
- cs.c = &cs
-
- // cs1 and cs2 are in a two object cycle.
- cs1 := cycleStruct{nil}
- cs2 := cycleStruct{nil}
- cs1.c = &cs2
- cs2.c = &cs1
-
- // bs is a single object cycle.
- bs := badCycleStruct{nil}
- bs.b = &bs
-
- // bs2 and bs2 are in a deadlocking cycle.
- bs1 := badCycleStruct{nil}
- bs2 := badCycleStruct{nil}
- bs1.b = &bs2
- bs2.b = &bs1
-
- // regular nils.
- var (
- nilmap dumbMap
- nilslice []byte
- )
-
- // embed points to embedded fields.
- embed1 := pointerStruct{}
- embed1.AA = &embed1.A
- embed2 := pointerStruct{}
- embed2.BB = &embed2.B
-
- // es1 contains two structs pointing to the same empty struct.
- es := emptyStructPointer{new(struct{})}
- es1 := []emptyStructPointer{es, es}
-
- tests := []TestCase{
- {
- Name: "bool",
- Objects: []interface{}{
- true,
- false,
- },
- },
- {
- Name: "integers",
- Objects: []interface{}{
- int(0),
- int(1),
- int(-1),
- int8(0),
- int8(1),
- int8(-1),
- int16(0),
- int16(1),
- int16(-1),
- int32(0),
- int32(1),
- int32(-1),
- int64(0),
- int64(1),
- int64(-1),
- },
- },
- {
- Name: "unsigned integers",
- Objects: []interface{}{
- uint(0),
- uint(1),
- uint8(0),
- uint8(1),
- uint16(0),
- uint16(1),
- uint32(1),
- uint64(0),
- uint64(1),
- },
- },
- {
- Name: "strings",
- Objects: []interface{}{
- "",
- "foo",
- "bar",
- "\xa0",
- },
- },
- {
- Name: "slices",
- Objects: []interface{}{
- []int{-1, 0, 1},
- []*int{&x, &x, &x},
- []int{1, 2, 3}[0:1],
- []int{1, 2, 3}[1:2],
- make([]byte, 32),
- make([]byte, 32)[:16],
- make([]byte, 32)[:16:20],
- nilslice,
- },
- },
- {
- Name: "arrays",
- Objects: []interface{}{
- &[1048576]bool{false, true, false, true},
- &[1048576]uint8{0, 1, 2, 3},
- &[1048576]byte{0, 1, 2, 3},
- &[1048576]uint16{0, 1, 2, 3},
- &[1048576]uint{0, 1, 2, 3},
- &[1048576]uint32{0, 1, 2, 3},
- &[1048576]uint64{0, 1, 2, 3},
- &[1048576]uintptr{0, 1, 2, 3},
- &[1048576]int8{0, -1, -2, -3},
- &[1048576]int16{0, -1, -2, -3},
- &[1048576]int32{0, -1, -2, -3},
- &[1048576]int64{0, -1, -2, -3},
- &[1048576]float32{0, 1.1, 2.2, 3.3},
- &[1048576]float64{0, 1.1, 2.2, 3.3},
- },
- },
- {
- Name: "pointers",
- Objects: []interface{}{
- &pointerStruct{A: &x, B: &x, C: &y, D: &y, AA: &xp, BB: &xp},
- &pointerStruct{},
- },
- },
- {
- Name: "empty struct",
- Objects: []interface{}{
- struct{}{},
- },
- },
- {
- Name: "unenlightened structs",
- Objects: []interface{}{
- &dumbStruct{A: 1, B: 2},
- },
- Fail: true,
- },
- {
- Name: "enlightened structs",
- Objects: []interface{}{
- &smartStruct{A: 1, B: 2},
- },
- },
- {
- Name: "load-hooks",
- Objects: []interface{}{
- &afterLoadStruct{v: 1},
- &valueLoadStruct{v: 1},
- &genericContainer{v: &afterLoadStruct{v: 1}},
- &genericContainer{v: &valueLoadStruct{v: 1}},
- &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}},
- &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}},
- },
- },
- {
- Name: "maps",
- Objects: []interface{}{
- dumbMap{"a": -1, "b": 0, "c": 1},
- map[smartStruct]int{{}: 0, {A: 1}: 1},
- nilmap,
- &mapContainer{v: map[int]interface{}{0: &smartStruct{A: 1}}},
- },
- },
- {
- Name: "interfaces",
- Objects: []interface{}{
- &testI{&testImpl{}},
- &testI{nil},
- &testI{(*testImpl)(nil)},
- },
- },
- {
- Name: "unregistered-interfaces",
- Objects: []interface{}{
- &genericContainer{v: afterLoadStruct{v: 1}},
- &genericContainer{v: valueLoadStruct{v: 1}},
- &sliceContainer{v: []interface{}{afterLoadStruct{v: 1}}},
- &sliceContainer{v: []interface{}{valueLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: afterLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: valueLoadStruct{v: 1}}},
- },
- Fail: true,
- },
- {
- Name: "cycles",
- Objects: []interface{}{
- &cs,
- &cs1,
- &cycleStruct{&cs1},
- &cycleStruct{&cs},
- &badCycleStruct{nil},
- &bs,
- },
- },
- {
- Name: "deadlock",
- Objects: []interface{}{
- &bs1,
- },
- Fail: true,
- },
- {
- Name: "embed",
- Objects: []interface{}{
- &embed1,
- &embed2,
- },
- Fail: true,
- },
- {
- Name: "empty structs",
- Objects: []interface{}{
- new(struct{}),
- es,
- es1,
- },
- },
- {
- Name: "truncated okay",
- Objects: []interface{}{
- &truncateInteger{v: 1},
- &truncateUnsignedInteger{v: 1},
- &truncateFloat{v: 1.0},
- },
- },
- {
- Name: "truncated bad",
- Objects: []interface{}{
- &truncateInteger{v: math.MaxInt32 + 1},
- &truncateUnsignedInteger{v: math.MaxUint32 + 1},
- &truncateFloat{v: math.MaxFloat32 * 2},
- },
- Fail: true,
- },
- }
-
- runTest(t, tests)
-}
-
-// benchStruct is used for benchmarking.
-type benchStruct struct {
- b *benchStruct
-
- // Dummy data is included to ensure that these objects are large.
- // This is to detect possible regression when registering objects.
- _ [4096]byte
-}
-
-func (b *benchStruct) save(m Map) {
- m.Save("b", &b.b)
-}
-
-func (b *benchStruct) load(m Map) {
- m.LoadWait("b", &b.b)
- m.AfterLoad(b.afterLoad)
-}
-
-func (b *benchStruct) afterLoad() {
- // Do nothing, just force scheduling.
-}
-
-// buildObject builds a benchmark object.
-func buildObject(n int) (b *benchStruct) {
- for i := 0; i < n; i++ {
- b = &benchStruct{b: b}
- }
- return
-}
-
-func BenchmarkEncoding(b *testing.B) {
- b.StopTimer()
- bs := buildObject(b.N)
- var stats Stats
- b.StartTimer()
- if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil {
- b.Errorf("save failed: %v", err)
- }
- b.StopTimer()
- if b.N > 1000 {
- b.Logf("breakdown (n=%d): %s", b.N, &stats)
- }
-}
-
-func BenchmarkDecoding(b *testing.B) {
- b.StopTimer()
- bs := buildObject(b.N)
- var newBS benchStruct
- buf := &bytes.Buffer{}
- if err := Save(context.Background(), buf, bs, nil); err != nil {
- b.Errorf("save failed: %v", err)
- }
- var stats Stats
- b.StartTimer()
- if err := Load(context.Background(), buf, &newBS, &stats); err != nil {
- b.Errorf("load failed: %v", err)
- }
- b.StopTimer()
- if b.N > 1000 {
- b.Logf("breakdown (n=%d): %s", b.N, &stats)
- }
-}
-
-func init() {
- Register("stateTest.smartStruct", (*smartStruct)(nil), Fns{
- Save: (*smartStruct).save,
- Load: (*smartStruct).load,
- })
- Register("stateTest.afterLoadStruct", (*afterLoadStruct)(nil), Fns{
- Save: (*afterLoadStruct).save,
- Load: (*afterLoadStruct).load,
- })
- Register("stateTest.valueLoadStruct", (*valueLoadStruct)(nil), Fns{
- Save: (*valueLoadStruct).save,
- Load: (*valueLoadStruct).load,
- })
- Register("stateTest.genericContainer", (*genericContainer)(nil), Fns{
- Save: (*genericContainer).save,
- Load: (*genericContainer).load,
- })
- Register("stateTest.sliceContainer", (*sliceContainer)(nil), Fns{
- Save: (*sliceContainer).save,
- Load: (*sliceContainer).load,
- })
- Register("stateTest.mapContainer", (*mapContainer)(nil), Fns{
- Save: (*mapContainer).save,
- Load: (*mapContainer).load,
- })
- Register("stateTest.pointerStruct", (*pointerStruct)(nil), Fns{
- Save: (*pointerStruct).save,
- Load: (*pointerStruct).load,
- })
- Register("stateTest.testImpl", (*testImpl)(nil), Fns{
- Save: (*testImpl).save,
- Load: (*testImpl).load,
- })
- Register("stateTest.testI", (*testI)(nil), Fns{
- Save: (*testI).save,
- Load: (*testI).load,
- })
- Register("stateTest.cycleStruct", (*cycleStruct)(nil), Fns{
- Save: (*cycleStruct).save,
- Load: (*cycleStruct).load,
- })
- Register("stateTest.badCycleStruct", (*badCycleStruct)(nil), Fns{
- Save: (*badCycleStruct).save,
- Load: (*badCycleStruct).load,
- })
- Register("stateTest.emptyStructPointer", (*emptyStructPointer)(nil), Fns{
- Save: (*emptyStructPointer).save,
- Load: (*emptyStructPointer).load,
- })
- Register("stateTest.truncateInteger", (*truncateInteger)(nil), Fns{
- Save: (*truncateInteger).save,
- Load: (*truncateInteger).load,
- })
- Register("stateTest.truncateUnsignedInteger", (*truncateUnsignedInteger)(nil), Fns{
- Save: (*truncateUnsignedInteger).save,
- Load: (*truncateUnsignedInteger).load,
- })
- Register("stateTest.truncateFloat", (*truncateFloat)(nil), Fns{
- Save: (*truncateFloat).save,
- Load: (*truncateFloat).load,
- })
- Register("stateTest.benchStruct", (*benchStruct)(nil), Fns{
- Save: (*benchStruct).save,
- Load: (*benchStruct).load,
- })
-}
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
index e7581c09b..d6c89c7e9 100644
--- a/pkg/state/statefile/BUILD
+++ b/pkg/state/statefile/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/binary",
"//pkg/compressio",
+ "//pkg/state/wire",
],
)
diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go
index c0f4c4954..bdfb800fb 100644
--- a/pkg/state/statefile/statefile.go
+++ b/pkg/state/statefile/statefile.go
@@ -57,6 +57,7 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/compressio"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
// keySize is the AES-256 key length.
@@ -83,10 +84,16 @@ var ErrInvalidMetadataLength = fmt.Errorf("metadata length invalid, maximum size
// ErrMetadataInvalid is returned if passed metadata is invalid.
var ErrMetadataInvalid = fmt.Errorf("metadata invalid, can't start with _")
+// WriteCloser is an io.Closer and wire.Writer.
+type WriteCloser interface {
+ wire.Writer
+ io.Closer
+}
+
// NewWriter returns a state data writer for a statefile.
//
// Note that the returned WriteCloser must be closed.
-func NewWriter(w io.Writer, key []byte, metadata map[string]string) (io.WriteCloser, error) {
+func NewWriter(w io.Writer, key []byte, metadata map[string]string) (WriteCloser, error) {
if metadata == nil {
metadata = make(map[string]string)
}
@@ -215,7 +222,7 @@ func metadata(r io.Reader, h hash.Hash) (map[string]string, error) {
}
// NewReader returns a reader for a statefile.
-func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) {
+func NewReader(r io.Reader, key []byte) (wire.Reader, map[string]string, error) {
// Read the metadata with the hash.
h := hmac.New(sha256.New, key)
metadata, err := metadata(r, h)
@@ -224,9 +231,9 @@ func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) {
}
// Wrap in compression.
- rc, err := compressio.NewReader(r, key)
+ cr, err := compressio.NewReader(r, key)
if err != nil {
return nil, nil, err
}
- return rc, metadata, nil
+ return cr, metadata, nil
}
diff --git a/pkg/state/stats.go b/pkg/state/stats.go
index eb51cda47..eaec664a1 100644
--- a/pkg/state/stats.go
+++ b/pkg/state/stats.go
@@ -17,7 +17,6 @@ package state
import (
"bytes"
"fmt"
- "reflect"
"sort"
"time"
)
@@ -35,92 +34,81 @@ type statEntry struct {
// All exported receivers accept nil.
type Stats struct {
// byType contains a breakdown of time spent by type.
- byType map[reflect.Type]*statEntry
+ //
+ // This is indexed *directly* by typeID, including zero.
+ byType []statEntry
// stack contains objects in progress.
- stack []reflect.Type
+ stack []typeID
+
+ // names contains type names.
+ //
+ // This is also indexed *directly* by typeID, including zero, which we
+ // hard-code as "state.default". This is only resolved by calling fini
+ // on the stats object.
+ names []string
// last is the last start time.
last time.Time
}
-// sample adds the samples to the given object.
-func (s *Stats) sample(typ reflect.Type) {
- now := time.Now()
- s.byType[typ].total += now.Sub(s.last)
- s.last = now
+// init initializes statistics.
+func (s *Stats) init() {
+ s.last = time.Now()
+ s.stack = append(s.stack, 0)
}
-// Add adds a sample count.
-func (s *Stats) Add(obj reflect.Value) {
- if s == nil {
- return
- }
- if s.byType == nil {
- s.byType = make(map[reflect.Type]*statEntry)
- }
- typ := obj.Type()
- entry, ok := s.byType[typ]
- if !ok {
- entry = new(statEntry)
- s.byType[typ] = entry
+// fini finalizes statistics.
+func (s *Stats) fini(resolve func(id typeID) string) {
+ s.done()
+
+ // Resolve all type names.
+ s.names = make([]string, len(s.byType))
+ s.names[0] = "state.default" // See above.
+ for id := typeID(1); int(id) < len(s.names); id++ {
+ s.names[id] = resolve(id)
}
- entry.count++
}
-// Remove removes a sample count. It should only be called after a previous
-// Add().
-func (s *Stats) Remove(obj reflect.Value) {
- if s == nil {
- return
+// sample adds the samples to the given object.
+func (s *Stats) sample(id typeID) {
+ now := time.Now()
+ if len(s.byType) <= int(id) {
+ // Allocate all the missing entries in one fell swoop.
+ s.byType = append(s.byType, make([]statEntry, 1+int(id)-len(s.byType))...)
}
- typ := obj.Type()
- entry := s.byType[typ]
- entry.count--
+ s.byType[id].total += now.Sub(s.last)
+ s.last = now
}
-// Start starts a sample.
-func (s *Stats) Start(obj reflect.Value) {
- if s == nil {
- return
- }
- if len(s.stack) > 0 {
- last := s.stack[len(s.stack)-1]
- s.sample(last)
- } else {
- // First time sample.
- s.last = time.Now()
- }
- s.stack = append(s.stack, obj.Type())
+// start starts a sample.
+func (s *Stats) start(id typeID) {
+ last := s.stack[len(s.stack)-1]
+ s.sample(last)
+ s.stack = append(s.stack, id)
}
-// Done finishes the current sample.
-func (s *Stats) Done() {
- if s == nil {
- return
- }
+// done finishes the current sample.
+func (s *Stats) done() {
last := s.stack[len(s.stack)-1]
s.sample(last)
+ s.byType[last].count++
s.stack = s.stack[:len(s.stack)-1]
}
type sliceEntry struct {
- typ reflect.Type
+ name string
entry *statEntry
}
// String returns a table representation of the stats.
func (s *Stats) String() string {
- if s == nil || len(s.byType) == 0 {
- return "(no data)"
- }
-
// Build a list of stat entries.
ss := make([]sliceEntry, 0, len(s.byType))
- for typ, entry := range s.byType {
+ for id := 0; id < len(s.names); id++ {
ss = append(ss, sliceEntry{
- typ: typ,
- entry: entry,
+ name: s.names[id],
+ entry: &s.byType[id],
})
}
@@ -136,17 +124,22 @@ func (s *Stats) String() string {
total time.Duration
)
buf.WriteString("\n")
- buf.WriteString(fmt.Sprintf("%12s | %8s | %8s | %s\n", "total", "count", "per", "type"))
- buf.WriteString("-------------+----------+----------+-------------\n")
+ buf.WriteString(fmt.Sprintf("% 16s | % 8s | % 16s | %s\n", "total", "count", "per", "type"))
+ buf.WriteString("-----------------+----------+------------------+----------------\n")
for _, se := range ss {
+ if se.entry.count == 0 {
+ // Since we store all types linearly, we are not
+ // guaranteed that any entry actually has time.
+ continue
+ }
count += se.entry.count
total += se.entry.total
per := se.entry.total / time.Duration(se.entry.count)
- buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | %s\n",
- se.entry.total, se.entry.count, per, se.typ.String()))
+ buf.WriteString(fmt.Sprintf("% 16s | %8d | % 16s | %s\n",
+ se.entry.total, se.entry.count, per, se.name))
}
- buf.WriteString("-------------+----------+----------+-------------\n")
- buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | [all]",
+ buf.WriteString("-----------------+----------+------------------+----------------\n")
+ buf.WriteString(fmt.Sprintf("% 16s | % 8d | % 16s | [all]",
total, count, total/time.Duration(count)))
return string(buf.Bytes())
}
diff --git a/pkg/state/tests/BUILD b/pkg/state/tests/BUILD
new file mode 100644
index 000000000..9297cafbe
--- /dev/null
+++ b/pkg/state/tests/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tests",
+ srcs = [
+ "array.go",
+ "bench.go",
+ "integer.go",
+ "load.go",
+ "map.go",
+ "register.go",
+ "struct.go",
+ "tests.go",
+ ],
+ deps = [
+ "//pkg/state",
+ "//pkg/state/pretty",
+ ],
+)
+
+go_test(
+ name = "tests_test",
+ size = "small",
+ srcs = [
+ "array_test.go",
+ "bench_test.go",
+ "bool_test.go",
+ "float_test.go",
+ "integer_test.go",
+ "load_test.go",
+ "map_test.go",
+ "register_test.go",
+ "string_test.go",
+ "struct_test.go",
+ ],
+ library = ":tests",
+ deps = [
+ "//pkg/state",
+ "//pkg/state/wire",
+ ],
+)
diff --git a/pkg/state/tests/array.go b/pkg/state/tests/array.go
new file mode 100644
index 000000000..0972a80e7
--- /dev/null
+++ b/pkg/state/tests/array.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type arrayContainer struct {
+ v [1]interface{}
+}
+
+// +stateify savable
+type arrayPtrContainer struct {
+ v *[1]interface{}
+}
+
+// +stateify savable
+type sliceContainer struct {
+ v []interface{}
+}
+
+// +stateify savable
+type slicePtrContainer struct {
+ v *[]interface{}
+}
diff --git a/pkg/state/tests/array_test.go b/pkg/state/tests/array_test.go
new file mode 100644
index 000000000..a347b2947
--- /dev/null
+++ b/pkg/state/tests/array_test.go
@@ -0,0 +1,134 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "reflect"
+ "testing"
+)
+
+var allArrayPrimitives = []interface{}{
+ [1]bool{},
+ [1]bool{true},
+ [2]bool{false, true},
+ [1]int{},
+ [1]int{1},
+ [2]int{0, 1},
+ [1]int8{},
+ [1]int8{1},
+ [2]int8{0, 1},
+ [1]int16{},
+ [1]int16{1},
+ [2]int16{0, 1},
+ [1]int32{},
+ [1]int32{1},
+ [2]int32{0, 1},
+ [1]int64{},
+ [1]int64{1},
+ [2]int64{0, 1},
+ [1]uint{},
+ [1]uint{1},
+ [2]uint{0, 1},
+ [1]uintptr{},
+ [1]uintptr{1},
+ [2]uintptr{0, 1},
+ [1]uint8{},
+ [1]uint8{1},
+ [2]uint8{0, 1},
+ [1]uint16{},
+ [1]uint16{1},
+ [2]uint16{0, 1},
+ [1]uint32{},
+ [1]uint32{1},
+ [2]uint32{0, 1},
+ [1]uint64{},
+ [1]uint64{1},
+ [2]uint64{0, 1},
+ [1]string{},
+ [1]string{""},
+ [1]string{nonEmptyString},
+ [2]string{"", nonEmptyString},
+}
+
+func TestArrayPrimitives(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allArrayPrimitives))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allArrayPrimitives)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allArrayPrimitives)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allArrayPrimitives))))
+}
+
+func TestSlices(t *testing.T) {
+ var allSlices = flatten(
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ return v.Slice(0, v.Len()).Interface(), true
+ }),
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ if v.Len() == 0 {
+ // Return the pure "nil" value for the slice.
+ return reflect.New(v.Slice(0, 0).Type()).Elem().Interface(), true
+ }
+ return v.Slice(1, v.Len()).Interface(), true
+ }),
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ if v.Len() == 0 {
+ // Return the zero-valued slice.
+ return reflect.MakeSlice(v.Slice(0, 0).Type(), 0, 0).Interface(), true
+ }
+ return v.Slice(0, v.Len()-1).Interface(), true
+ }),
+ )
+ runTestCases(t, false, "plain", allSlices)
+ runTestCases(t, false, "pointers", pointersTo(allSlices))
+ runTestCases(t, false, "interfaces", interfacesTo(allSlices))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(allSlices)))
+}
+
+func TestArrayContainers(t *testing.T) {
+ var (
+ emptyArray [1]interface{}
+ fullArray [1]interface{}
+ )
+ fullArray[0] = &emptyArray
+ runTestCases(t, false, "", []interface{}{
+ arrayContainer{v: emptyArray},
+ arrayContainer{v: fullArray},
+ arrayPtrContainer{v: nil},
+ arrayPtrContainer{v: &emptyArray},
+ arrayPtrContainer{v: &fullArray},
+ })
+}
+
+func TestSliceContainers(t *testing.T) {
+ var (
+ nilSlice []interface{}
+ emptySlice = make([]interface{}, 0)
+ fullSlice = []interface{}{nil}
+ )
+ runTestCases(t, false, "", []interface{}{
+ sliceContainer{v: nilSlice},
+ sliceContainer{v: emptySlice},
+ sliceContainer{v: fullSlice},
+ slicePtrContainer{v: nil},
+ slicePtrContainer{v: &nilSlice},
+ slicePtrContainer{v: &emptySlice},
+ slicePtrContainer{v: &fullSlice},
+ })
+}
diff --git a/pkg/state/tests/bench.go b/pkg/state/tests/bench.go
new file mode 100644
index 000000000..40869cdfb
--- /dev/null
+++ b/pkg/state/tests/bench.go
@@ -0,0 +1,24 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type benchStruct struct {
+ B *benchStruct // Must be exported for gob.
+}
+
+func (b *benchStruct) afterLoad() {
+ // Do nothing, just force scheduling.
+}
diff --git a/pkg/state/tests/bench_test.go b/pkg/state/tests/bench_test.go
new file mode 100644
index 000000000..7e102c907
--- /dev/null
+++ b/pkg/state/tests/bench_test.go
@@ -0,0 +1,153 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "bytes"
+ "context"
+ "encoding/gob"
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// buildPtrObject builds a benchmark object.
+func buildPtrObject(n int) interface{} {
+ b := new(benchStruct)
+ for i := 0; i < n; i++ {
+ b = &benchStruct{B: b}
+ }
+ return b
+}
+
+// buildMapObject builds a benchmark object.
+func buildMapObject(n int) interface{} {
+ b := new(benchStruct)
+ m := make(map[int]*benchStruct)
+ for i := 0; i < n; i++ {
+ m[i] = b
+ }
+ return &m
+}
+
+// buildSliceObject builds a benchmark object.
+func buildSliceObject(n int) interface{} {
+ b := new(benchStruct)
+ s := make([]*benchStruct, 0, n)
+ for i := 0; i < n; i++ {
+ s = append(s, b)
+ }
+ return &s
+}
+
+var allObjects = map[string]struct {
+ New func(int) interface{}
+}{
+ "ptr": {
+ New: buildPtrObject,
+ },
+ "map": {
+ New: buildMapObject,
+ },
+ "slice": {
+ New: buildSliceObject,
+ },
+}
+
+func buildObjects(n int, fn func(int) interface{}) (iters int, v interface{}) {
+ // maxSize is the maximum size of an individual object below. For an N
+ // larger than this, we start to return multiple objects.
+ const maxSize = 1024
+ if n <= maxSize {
+ return 1, fn(n)
+ }
+ iters = (n + maxSize - 1) / maxSize
+ return iters, fn(maxSize)
+}
+
+// gobSave is a version of save using gob (no stats available).
+func gobSave(_ context.Context, w wire.Writer, v interface{}) (_ state.Stats, err error) {
+ enc := gob.NewEncoder(w)
+ err = enc.Encode(v)
+ return
+}
+
+// gobLoad is a version of load using gob (no stats available).
+func gobLoad(_ context.Context, r wire.Reader, v interface{}) (_ state.Stats, err error) {
+ dec := gob.NewDecoder(r)
+ err = dec.Decode(v)
+ return
+}
+
+var allAlgos = map[string]struct {
+ Save func(context.Context, wire.Writer, interface{}) (state.Stats, error)
+ Load func(context.Context, wire.Reader, interface{}) (state.Stats, error)
+ MaxPtr int
+}{
+ "state": {
+ Save: state.Save,
+ Load: state.Load,
+ },
+ "gob": {
+ Save: gobSave,
+ Load: gobLoad,
+ },
+}
+
+func BenchmarkEncoding(b *testing.B) {
+ for objName, objInfo := range allObjects {
+ for algoName, algoInfo := range allAlgos {
+ b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
+ b.StopTimer()
+ n, v := buildObjects(b.N, objInfo.New)
+ b.ReportAllocs()
+ b.StartTimer()
+ for i := 0; i < n; i++ {
+ if _, err := algoInfo.Save(context.Background(), discard{}, v); err != nil {
+ b.Errorf("save failed: %v", err)
+ }
+ }
+ b.StopTimer()
+ })
+ }
+ }
+}
+
+func BenchmarkDecoding(b *testing.B) {
+ for objName, objInfo := range allObjects {
+ for algoName, algoInfo := range allAlgos {
+ b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
+ b.StopTimer()
+ n, v := buildObjects(b.N, objInfo.New)
+ buf := new(bytes.Buffer)
+ if _, err := algoInfo.Save(context.Background(), buf, v); err != nil {
+ b.Errorf("save failed: %v", err)
+ }
+ b.ReportAllocs()
+ b.StartTimer()
+ var r bytes.Reader
+ for i := 0; i < n; i++ {
+ r.Reset(buf.Bytes())
+ if _, err := algoInfo.Load(context.Background(), &r, v); err != nil {
+ b.Errorf("load failed: %v", err)
+ }
+ }
+ b.StopTimer()
+ })
+ }
+ }
+}
diff --git a/pkg/state/tests/bool_test.go b/pkg/state/tests/bool_test.go
new file mode 100644
index 000000000..e17cfacf9
--- /dev/null
+++ b/pkg/state/tests/bool_test.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+var allBools = []bool{
+ true,
+ false,
+}
+
+func TestBool(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allBools))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allBools)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allBools)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allBools))))
+}
diff --git a/pkg/state/tests/float_test.go b/pkg/state/tests/float_test.go
new file mode 100644
index 000000000..3e89edd9c
--- /dev/null
+++ b/pkg/state/tests/float_test.go
@@ -0,0 +1,118 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "math"
+ "testing"
+)
+
+var safeFloat32s = []float32{
+ float32(0.0),
+ float32(1.0),
+ float32(-1.0),
+ float32(math.Inf(1)),
+ float32(math.Inf(-1)),
+}
+
+var allFloat32s = append(safeFloat32s, float32(math.NaN()))
+
+var safeFloat64s = []float64{
+ float64(0.0),
+ float64(1.0),
+ float64(-1.0),
+ math.Inf(1),
+ math.Inf(-1),
+}
+
+var allFloat64s = append(safeFloat64s, math.NaN())
+
+func TestFloat(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(
+ allFloat32s,
+ allFloat64s,
+ ))
+ // See checkEqual for why NaNs are missing.
+ runTestCases(t, false, "pointers", pointersTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ )))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ )))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ ))))
+}
+
+const onlyDouble float64 = 1.0000000000000002
+
+func TestFloatTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingFloat32{save: onlyDouble},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingFloat32{save: 1.0},
+ })
+}
+
+var safeComplex64s = combine(safeFloat32s, safeFloat32s, func(i, j interface{}) interface{} {
+ return complex(i.(float32), j.(float32))
+})
+
+var allComplex64s = combine(allFloat32s, allFloat32s, func(i, j interface{}) interface{} {
+ return complex(i.(float32), j.(float32))
+})
+
+var safeComplex128s = combine(safeFloat64s, safeFloat64s, func(i, j interface{}) interface{} {
+ return complex(i.(float64), j.(float64))
+})
+
+var allComplex128s = combine(allFloat64s, allFloat64s, func(i, j interface{}) interface{} {
+ return complex(i.(float64), j.(float64))
+})
+
+func TestComplex(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(
+ allComplex64s,
+ allComplex128s,
+ ))
+ // See TestFloat; same issue.
+ runTestCases(t, false, "pointers", pointersTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ )))
+ runTestCases(t, false, "interfacse", interfacesTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ )))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ ))))
+}
+
+func TestComplexTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingComplex64{save: complex(onlyDouble, onlyDouble)},
+ truncatingComplex64{save: complex(1.0, onlyDouble)},
+ truncatingComplex64{save: complex(onlyDouble, 1.0)},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingComplex64{save: complex(1.0, 1.0)},
+ })
+}
diff --git a/pkg/state/tests/integer.go b/pkg/state/tests/integer.go
new file mode 100644
index 000000000..ca403eed1
--- /dev/null
+++ b/pkg/state/tests/integer.go
@@ -0,0 +1,163 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+// +stateify type
+type truncatingUint8 struct {
+ save uint64
+ load uint8 `state:"nosave"`
+}
+
+func (t *truncatingUint8) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint8) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint8)(nil)
+
+// +stateify type
+type truncatingUint16 struct {
+ save uint64
+ load uint16 `state:"nosave"`
+}
+
+func (t *truncatingUint16) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint16) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint16)(nil)
+
+// +stateify type
+type truncatingUint32 struct {
+ save uint64
+ load uint32 `state:"nosave"`
+}
+
+func (t *truncatingUint32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint32)(nil)
+
+// +stateify type
+type truncatingInt8 struct {
+ save int64
+ load int8 `state:"nosave"`
+}
+
+func (t *truncatingInt8) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt8) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt8)(nil)
+
+// +stateify type
+type truncatingInt16 struct {
+ save int64
+ load int16 `state:"nosave"`
+}
+
+func (t *truncatingInt16) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt16) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt16)(nil)
+
+// +stateify type
+type truncatingInt32 struct {
+ save int64
+ load int32 `state:"nosave"`
+}
+
+func (t *truncatingInt32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt32)(nil)
+
+// +stateify type
+type truncatingFloat32 struct {
+ save float64
+ load float32 `state:"nosave"`
+}
+
+func (t *truncatingFloat32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingFloat32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = float64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingFloat32)(nil)
+
+// +stateify type
+type truncatingComplex64 struct {
+ save complex128
+ load complex64 `state:"nosave"`
+}
+
+func (t *truncatingComplex64) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingComplex64) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = complex128(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingComplex64)(nil)
diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go
new file mode 100644
index 000000000..d3931c952
--- /dev/null
+++ b/pkg/state/tests/integer_test.go
@@ -0,0 +1,94 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "math"
+ "testing"
+)
+
+var (
+ allIntTs = []int{-1, 0, 1}
+ allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
+ allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
+ allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
+ allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
+ allUintTs = []uint{0, 1}
+ allUintptrs = []uintptr{0, 1, ^uintptr(0)}
+ allUint8s = []uint8{0, 1, math.MaxUint8}
+ allUint16s = []uint16{0, 1, math.MaxUint16}
+ allUint32s = []uint32{0, 1, math.MaxUint32}
+ allUint64s = []uint64{0, 1, math.MaxUint64}
+)
+
+var allInts = flatten(
+ allIntTs,
+ allInt8s,
+ allInt16s,
+ allInt32s,
+ allInt64s,
+)
+
+var allUints = flatten(
+ allUintTs,
+ allUintptrs,
+ allUint8s,
+ allUint16s,
+ allUint32s,
+ allUint64s,
+)
+
+func TestInt(t *testing.T) {
+ runTestCases(t, false, "plain", allInts)
+ runTestCases(t, false, "pointers", pointersTo(allInts))
+ runTestCases(t, false, "interfaces", interfacesTo(allInts))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allInts)))
+}
+
+func TestIntTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingInt8{save: math.MinInt8 - 1},
+ truncatingInt16{save: math.MinInt16 - 1},
+ truncatingInt32{save: math.MinInt32 - 1},
+ truncatingInt8{save: math.MaxInt8 + 1},
+ truncatingInt16{save: math.MaxInt16 + 1},
+ truncatingInt32{save: math.MaxInt32 + 1},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingInt8{save: 1},
+ truncatingInt16{save: 1},
+ truncatingInt32{save: 1},
+ })
+}
+
+func TestUint(t *testing.T) {
+ runTestCases(t, false, "plain", allUints)
+ runTestCases(t, false, "pointers", pointersTo(allUints))
+ runTestCases(t, false, "interfaces", interfacesTo(allUints))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allUints)))
+}
+
+func TestUintTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingUint8{save: math.MaxUint8 + 1},
+ truncatingUint16{save: math.MaxUint16 + 1},
+ truncatingUint32{save: math.MaxUint32 + 1},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingUint8{save: 1},
+ truncatingUint16{save: 1},
+ truncatingUint32{save: 1},
+ })
+}
diff --git a/pkg/state/tests/load.go b/pkg/state/tests/load.go
new file mode 100644
index 000000000..a8350c0f3
--- /dev/null
+++ b/pkg/state/tests/load.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type genericContainer struct {
+ v interface{}
+}
+
+// +stateify savable
+type afterLoadStruct struct {
+ v int `state:"nosave"`
+}
+
+func (a *afterLoadStruct) afterLoad() {
+ a.v++
+}
+
+// +stateify savable
+type valueLoadStruct struct {
+ v int `state:".(int64)"`
+}
+
+func (v *valueLoadStruct) saveV() int64 {
+ return int64(v.v) // Save as int64.
+}
+
+func (v *valueLoadStruct) loadV(value int64) {
+ v.v = int(value) // Load as int.
+}
+
+// +stateify savable
+type cycleStruct struct {
+ c *cycleStruct
+}
+
+// +stateify savable
+type badCycleStruct struct {
+ b *badCycleStruct `state:"wait"`
+}
+
+func (b *badCycleStruct) afterLoad() {
+ if b.b != b {
+ // This is not executable, since AfterLoad requires that the
+ // object and all dependencies are complete. This should cause
+ // a deadlock error during load.
+ panic("badCycleStruct.afterLoad called")
+ }
+}
diff --git a/pkg/state/tests/load_test.go b/pkg/state/tests/load_test.go
new file mode 100644
index 000000000..1e9794296
--- /dev/null
+++ b/pkg/state/tests/load_test.go
@@ -0,0 +1,70 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+func TestLoadHooks(t *testing.T) {
+ runTestCases(t, false, "load-hooks", []interface{}{
+ &afterLoadStruct{v: 1},
+ &valueLoadStruct{v: 1},
+ &genericContainer{v: &afterLoadStruct{v: 1}},
+ &genericContainer{v: &valueLoadStruct{v: 1}},
+ &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}},
+ &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}},
+ &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}},
+ &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}},
+ })
+}
+
+func TestCycles(t *testing.T) {
+ // cs is a single object cycle.
+ cs := cycleStruct{nil}
+ cs.c = &cs
+
+ // cs1 and cs2 are in a two object cycle.
+ cs1 := cycleStruct{nil}
+ cs2 := cycleStruct{nil}
+ cs1.c = &cs2
+ cs2.c = &cs1
+
+ runTestCases(t, false, "cycles", []interface{}{
+ cs,
+ cs1,
+ })
+}
+
+func TestDeadlock(t *testing.T) {
+ // bs is a single object cycle. This does not cause deadlock because an
+ // object cannot wait for itself.
+ bs := badCycleStruct{nil}
+ bs.b = &bs
+
+ runTestCases(t, false, "self", []interface{}{
+ &bs,
+ })
+
+ // bs2 and bs2 are in a deadlocking cycle.
+ bs1 := badCycleStruct{nil}
+ bs2 := badCycleStruct{nil}
+ bs1.b = &bs2
+ bs2.b = &bs1
+
+ runTestCases(t, true, "deadlock", []interface{}{
+ &bs1,
+ })
+}
diff --git a/pkg/state/tests/map.go b/pkg/state/tests/map.go
new file mode 100644
index 000000000..db4e548f1
--- /dev/null
+++ b/pkg/state/tests/map.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type mapContainer struct {
+ v map[int]interface{}
+}
+
+// +stateify savable
+type mapPtrContainer struct {
+ v *map[int]interface{}
+}
+
+// +stateify savable
+type registeredMapStruct struct{}
diff --git a/pkg/state/tests/map_test.go b/pkg/state/tests/map_test.go
new file mode 100644
index 000000000..92bf0fc01
--- /dev/null
+++ b/pkg/state/tests/map_test.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "reflect"
+ "testing"
+)
+
+var allMapPrimitives = []interface{}{
+ bool(true),
+ int(1),
+ int8(1),
+ int16(1),
+ int32(1),
+ int64(1),
+ uint(1),
+ uintptr(1),
+ uint8(1),
+ uint16(1),
+ uint32(1),
+ uint64(1),
+ string(""),
+ registeredMapStruct{},
+}
+
+var allMapKeys = flatten(allMapPrimitives, pointersTo(allMapPrimitives))
+
+var allMapValues = flatten(allMapPrimitives, pointersTo(allMapPrimitives), interfacesTo(allMapPrimitives))
+
+var emptyMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} {
+ m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2)))
+ return m.Interface()
+})
+
+var fullMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} {
+ m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2)))
+ m.SetMapIndex(reflect.Zero(reflect.TypeOf(v1)), reflect.Zero(reflect.TypeOf(v2)))
+ return m.Interface()
+})
+
+func TestMapAliasing(t *testing.T) {
+ v := make(map[int]int)
+ ptrToV := &v
+ aliases := []map[int]int{v, v}
+ runTestCases(t, false, "", []interface{}{ptrToV, aliases})
+}
+
+func TestMapsEmpty(t *testing.T) {
+ runTestCases(t, false, "plain", emptyMaps)
+ runTestCases(t, false, "pointers", pointersTo(emptyMaps))
+ runTestCases(t, false, "interfaces", interfacesTo(emptyMaps))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(emptyMaps)))
+}
+
+func TestMapsFull(t *testing.T) {
+ runTestCases(t, false, "plain", fullMaps)
+ runTestCases(t, false, "pointers", pointersTo(fullMaps))
+ runTestCases(t, false, "interfaces", interfacesTo(fullMaps))
+ runTestCases(t, false, "interfacesToPointer", interfacesTo(pointersTo(fullMaps)))
+}
+
+func TestMapContainers(t *testing.T) {
+ var (
+ nilMap map[int]interface{}
+ emptyMap = make(map[int]interface{})
+ fullMap = map[int]interface{}{0: nil}
+ )
+ runTestCases(t, false, "", []interface{}{
+ mapContainer{v: nilMap},
+ mapContainer{v: emptyMap},
+ mapContainer{v: fullMap},
+ mapPtrContainer{v: nil},
+ mapPtrContainer{v: &nilMap},
+ mapPtrContainer{v: &emptyMap},
+ mapPtrContainer{v: &fullMap},
+ })
+}
diff --git a/pkg/state/tests/register.go b/pkg/state/tests/register.go
new file mode 100644
index 000000000..074d86315
--- /dev/null
+++ b/pkg/state/tests/register.go
@@ -0,0 +1,21 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type alreadyRegisteredStruct struct{}
+
+// +stateify savable
+type alreadyRegisteredOther int
diff --git a/pkg/state/tests/register_test.go b/pkg/state/tests/register_test.go
new file mode 100644
index 000000000..c829753cc
--- /dev/null
+++ b/pkg/state/tests/register_test.go
@@ -0,0 +1,167 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+// faker calls itself whatever is in the name field.
+type faker struct {
+ Name string
+ Fields []string
+}
+
+func (f *faker) StateTypeName() string {
+ return f.Name
+}
+
+func (f *faker) StateFields() []string {
+ return f.Fields
+}
+
+// fakerWithSaverLoader has all it needs.
+type fakerWithSaverLoader struct {
+ faker
+}
+
+func (f *fakerWithSaverLoader) StateSave(m state.Sink) {}
+
+func (f *fakerWithSaverLoader) StateLoad(m state.Source) {}
+
+// fakerOther calls itself .. uh, itself?
+type fakerOther string
+
+func (f *fakerOther) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOther) StateFields() []string {
+ return nil
+}
+
+func newFakerOther(name string) *fakerOther {
+ f := fakerOther(name)
+ return &f
+}
+
+// fakerOtherBadFields returns non-nil fields.
+type fakerOtherBadFields string
+
+func (f *fakerOtherBadFields) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOtherBadFields) StateFields() []string {
+ return []string{string(*f)}
+}
+
+func newFakerOtherBadFields(name string) *fakerOtherBadFields {
+ f := fakerOtherBadFields(name)
+ return &f
+}
+
+// fakerOtherSaverLoader implements SaverLoader methods.
+type fakerOtherSaverLoader string
+
+func (f *fakerOtherSaverLoader) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOtherSaverLoader) StateFields() []string {
+ return nil
+}
+
+func (f *fakerOtherSaverLoader) StateSave(m state.Sink) {}
+
+func (f *fakerOtherSaverLoader) StateLoad(m state.Source) {}
+
+func newFakerOtherSaverLoader(name string) *fakerOtherSaverLoader {
+ f := fakerOtherSaverLoader(name)
+ return &f
+}
+
+func TestRegisterPrimitives(t *testing.T) {
+ for _, typeName := range []string{
+ "int",
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "uint",
+ "uintptr",
+ "uint8",
+ "uint16",
+ "uint32",
+ "uint64",
+ "float32",
+ "float64",
+ "complex64",
+ "complex128",
+ "string",
+ } {
+ t.Run("struct/"+typeName, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering type %q did not panic", typeName)
+ }
+ }()
+ state.Register(&faker{
+ Name: typeName,
+ })
+ })
+ t.Run("other/"+typeName, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering type %q did not panic", typeName)
+ }
+ }()
+ state.Register(newFakerOther(typeName))
+ })
+ }
+}
+
+func TestRegisterBad(t *testing.T) {
+ const (
+ goodName = "foo"
+ firstField = "a"
+ secondField = "b"
+ )
+ for name, object := range map[string]state.Type{
+ "non-struct-with-fields": newFakerOtherBadFields(goodName),
+ "non-struct-with-saverloader": newFakerOtherSaverLoader(goodName),
+ "struct-without-saverloader": &faker{Name: goodName},
+ "non-struct-duplicate-with-struct": newFakerOther((new(alreadyRegisteredStruct)).StateTypeName()),
+ "non-struct-duplicate-with-non-struct": newFakerOther((new(alreadyRegisteredOther)).StateTypeName()),
+ "struct-duplicate-with-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredStruct)).StateTypeName()}},
+ "struct-duplicate-with-non-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredOther)).StateTypeName()}},
+ "struct-with-empty-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{""}}},
+ "struct-with-empty-field-and-non-empty": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, ""}}},
+ "struct-with-duplicate-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, firstField}}},
+ "struct-with-duplicate-field-and-non-dup": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, secondField, firstField}}},
+ } {
+ t.Run(name, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering object %#v did not panic", object)
+ }
+ }()
+ state.Register(object)
+ })
+
+ }
+}
diff --git a/pkg/state/tests/string_test.go b/pkg/state/tests/string_test.go
new file mode 100644
index 000000000..44f5a562c
--- /dev/null
+++ b/pkg/state/tests/string_test.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+const nonEmptyString = "hello world"
+
+var allStrings = []string{
+ "",
+ nonEmptyString,
+ "\\0",
+}
+
+func TestString(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allStrings))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allStrings)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allStrings)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allStrings))))
+}
diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go
new file mode 100644
index 000000000..bd2c2b399
--- /dev/null
+++ b/pkg/state/tests/struct.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+type unregisteredEmptyStruct struct{}
+
+// typeOnlyEmptyStruct just implements the state.Type interface.
+type typeOnlyEmptyStruct struct{}
+
+func (*typeOnlyEmptyStruct) StateTypeName() string { return "registeredEmptyStruct" }
+
+func (*typeOnlyEmptyStruct) StateFields() []string { return nil }
+
+// +stateify savable
+type savableEmptyStruct struct{}
+
+// +stateify savable
+type emptyStructPointer struct {
+ nothing *struct{}
+}
+
+// +stateify savable
+type outerSame struct {
+ inner inner
+}
+
+// +stateify savable
+type outerFieldFirst struct {
+ inner inner
+ v int64
+}
+
+// +stateify savable
+type outerFieldSecond struct {
+ v int64
+ inner inner
+}
+
+// +stateify savable
+type outerArray struct {
+ inner [2]inner
+}
+
+// +stateify savable
+type inner struct {
+ v int64
+}
+
+// +stateify savable
+type system struct {
+ v1 interface{}
+ v2 interface{}
+}
diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go
new file mode 100644
index 000000000..de9d17aa7
--- /dev/null
+++ b/pkg/state/tests/struct_test.go
@@ -0,0 +1,89 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func TestEmptyStruct(t *testing.T) {
+ runTestCases(t, false, "plain", []interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ savableEmptyStruct{},
+ })
+ runTestCases(t, false, "pointers", pointersTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ savableEmptyStruct{},
+ }))
+ runTestCases(t, false, "interfaces-pass", interfacesTo([]interface{}{
+ // Only registered types can be dispatched via interfaces. All
+ // other types should fail, even if it is the empty struct.
+ savableEmptyStruct{},
+ }))
+ runTestCases(t, true, "interfaces-fail", interfacesTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ }))
+ runTestCases(t, false, "interfacesToPointers-pass", interfacesTo(pointersTo([]interface{}{
+ savableEmptyStruct{},
+ })))
+ runTestCases(t, true, "interfacesToPointers-fail", interfacesTo(pointersTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ })))
+
+ // Ensuring empty struct aliasing works.
+ es := emptyStructPointer{new(struct{})}
+ runTestCases(t, false, "empty-struct-pointers", []interface{}{
+ emptyStructPointer{},
+ es,
+ []emptyStructPointer{es, es}, // Same pointer.
+ })
+}
+
+func TestRegisterTypeOnlyStruct(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Register did not panic")
+ }
+ }()
+ state.Register((*typeOnlyEmptyStruct)(nil))
+}
+
+func TestEmbeddedPointers(t *testing.T) {
+ var (
+ ofs outerSame
+ of1 outerFieldFirst
+ of2 outerFieldSecond
+ oa outerArray
+ )
+
+ runTestCases(t, false, "embedded-pointers", []interface{}{
+ system{&ofs, &ofs.inner},
+ system{&ofs.inner, &ofs},
+ system{&of1, &of1.inner},
+ system{&of1.inner, &of1},
+ system{&of2, &of2.inner},
+ system{&of2.inner, &of2},
+ system{&oa, &oa.inner[0]},
+ system{&oa, &oa.inner[1]},
+ system{&oa.inner[0], &oa},
+ system{&oa.inner[1], &oa},
+ })
+}
diff --git a/pkg/state/tests/tests.go b/pkg/state/tests/tests.go
new file mode 100644
index 000000000..435a0e9db
--- /dev/null
+++ b/pkg/state/tests/tests.go
@@ -0,0 +1,215 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tests tests the state packages.
+package tests
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "math"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/pretty"
+)
+
+// discard is an implementation of wire.Writer.
+type discard struct{}
+
+// Write implements wire.Writer.Write.
+func (discard) Write(p []byte) (int, error) { return len(p), nil }
+
+// WriteByte implements wire.Writer.WriteByte.
+func (discard) WriteByte(byte) error { return nil }
+
+// checkEqual checks if two objects are equal.
+//
+// N.B. This only handles one level of dereferences for NaN. Otherwise we
+// would need to fork the entire implementation of reflect.DeepEqual.
+func checkEqual(root, loadedValue interface{}) bool {
+ if reflect.DeepEqual(root, loadedValue) {
+ return true
+ }
+
+ // NaN is not equal to itself. We handle the case of raw floating point
+ // primitives here, but don't handle this case nested.
+ rf32, ok1 := root.(float32)
+ lf32, ok2 := loadedValue.(float32)
+ if ok1 && ok2 && math.IsNaN(float64(rf32)) && math.IsNaN(float64(lf32)) {
+ return true
+ }
+ rf64, ok1 := root.(float64)
+ lf64, ok2 := loadedValue.(float64)
+ if ok1 && ok2 && math.IsNaN(rf64) && math.IsNaN(lf64) {
+ return true
+ }
+
+ // Same real for complex numbers.
+ rc64, ok1 := root.(complex64)
+ lc64, ok2 := root.(complex64)
+ if ok1 && ok2 {
+ return checkEqual(real(rc64), real(lc64)) && checkEqual(imag(rc64), imag(lc64))
+ }
+ rc128, ok1 := root.(complex128)
+ lc128, ok2 := root.(complex128)
+ if ok1 && ok2 {
+ return checkEqual(real(rc128), real(lc128)) && checkEqual(imag(rc128), imag(lc128))
+ }
+
+ return false
+}
+
+// runTestCases runs a test for each object in objects.
+func runTestCases(t *testing.T, shouldFail bool, prefix string, objects []interface{}) {
+ t.Helper()
+ for i, root := range objects {
+ t.Run(fmt.Sprintf("%s%d", prefix, i), func(t *testing.T) {
+ t.Logf("Original object:\n%#v", root)
+
+ // Save the passed object.
+ saveBuffer := &bytes.Buffer{}
+ saveObjectPtr := reflect.New(reflect.TypeOf(root))
+ saveObjectPtr.Elem().Set(reflect.ValueOf(root))
+ saveStats, err := state.Save(context.Background(), saveBuffer, saveObjectPtr.Interface())
+ if err != nil {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Save failed unexpectedly: %v", err)
+ }
+
+ // Dump the serialized proto to aid with debugging.
+ var ppBuf bytes.Buffer
+ t.Logf("Raw state:\n%v", saveBuffer.Bytes())
+ if err := pretty.PrintText(&ppBuf, bytes.NewReader(saveBuffer.Bytes())); err != nil {
+ // We don't count this as a test failure if we
+ // have shouldFail set, but we will count as a
+ // failure if we were not expecting to fail.
+ if !shouldFail {
+ t.Errorf("PrettyPrint(html=false) failed unexpected: %v", err)
+ }
+ }
+ if err := pretty.PrintHTML(discard{}, bytes.NewReader(saveBuffer.Bytes())); err != nil {
+ // See above.
+ if !shouldFail {
+ t.Errorf("PrettyPrint(html=true) failed unexpected: %v", err)
+ }
+ }
+ t.Logf("Encoded state:\n%s", ppBuf.String())
+ t.Logf("Save stats:\n%s", saveStats.String())
+
+ // Load a new copy of the object.
+ loadObjectPtr := reflect.New(reflect.TypeOf(root))
+ loadStats, err := state.Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface())
+ if err != nil {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Load failed unexpectedly: %v", err)
+ }
+
+ // Compare the values.
+ loadedValue := loadObjectPtr.Elem().Interface()
+ if !checkEqual(root, loadedValue) {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Objects differ:\n\toriginal: %#v\n\tloaded: %#v\n", root, loadedValue)
+ }
+
+ // Everything went okay. Is that good?
+ if shouldFail {
+ t.Fatalf("This test was expected to fail, but didn't.")
+ }
+ t.Logf("Load stats:\n%s", loadStats.String())
+
+ // Truncate half the bytes in the byte stream,
+ // and ensure that we can't restore. Then
+ // truncate only the final byte and ensure that
+ // we can't restore.
+ l := saveBuffer.Len()
+ halfReader := bytes.NewReader(saveBuffer.Bytes()[:l/2])
+ if _, err := state.Load(context.Background(), halfReader, loadObjectPtr.Interface()); err == nil {
+ t.Errorf("Load with half bytes succeeded unexpectedly.")
+ }
+ missingByteReader := bytes.NewReader(saveBuffer.Bytes()[:l-1])
+ if _, err := state.Load(context.Background(), missingByteReader, loadObjectPtr.Interface()); err == nil {
+ t.Errorf("Load with missing byte succeeded unexpectedly.")
+ }
+ })
+ }
+}
+
+// convert converts the slice to an []interface{}.
+func convert(v interface{}) (r []interface{}) {
+ s := reflect.ValueOf(v) // Must be slice.
+ for i := 0; i < s.Len(); i++ {
+ r = append(r, s.Index(i).Interface())
+ }
+ return r
+}
+
+// flatten flattens multiple slices.
+func flatten(vs ...interface{}) (r []interface{}) {
+ for _, v := range vs {
+ r = append(r, convert(v)...)
+ }
+ return r
+}
+
+// filter maps from one slice to another.
+func filter(vs interface{}, fn func(interface{}) (interface{}, bool)) (r []interface{}) {
+ s := reflect.ValueOf(vs)
+ for i := 0; i < s.Len(); i++ {
+ v, ok := fn(s.Index(i).Interface())
+ if ok {
+ r = append(r, v)
+ }
+ }
+ return r
+}
+
+// combine combines objects in two slices as specified.
+func combine(v1, v2 interface{}, fn func(_, _ interface{}) interface{}) (r []interface{}) {
+ s1 := reflect.ValueOf(v1)
+ s2 := reflect.ValueOf(v2)
+ for i := 0; i < s1.Len(); i++ {
+ for j := 0; j < s2.Len(); j++ {
+ // Combine using the given function.
+ r = append(r, fn(s1.Index(i).Interface(), s2.Index(j).Interface()))
+ }
+ }
+ return r
+}
+
+// pointersTo is a filter function that returns pointers.
+func pointersTo(vs interface{}) []interface{} {
+ return filter(vs, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o))
+ v.Elem().Set(reflect.ValueOf(o))
+ return v.Interface(), true
+ })
+}
+
+// interfacesTo is a filter function that returns interface objects.
+func interfacesTo(vs interface{}) []interface{} {
+ return filter(vs, func(o interface{}) (interface{}, bool) {
+ var v [1]interface{}
+ v[0] = o
+ return v, true
+ })
+}
diff --git a/pkg/state/types.go b/pkg/state/types.go
new file mode 100644
index 000000000..215ef80f8
--- /dev/null
+++ b/pkg/state/types.go
@@ -0,0 +1,361 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "sort"
+
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// assertValidType asserts that the type is valid.
+func assertValidType(name string, fields []string) {
+ if name == "" {
+ Failf("type has empty name")
+ }
+ fieldsCopy := make([]string, len(fields))
+ for i := 0; i < len(fields); i++ {
+ if fields[i] == "" {
+ Failf("field has empty name for type %q", name)
+ }
+ fieldsCopy[i] = fields[i]
+ }
+ sort.Slice(fieldsCopy, func(i, j int) bool {
+ return fieldsCopy[i] < fieldsCopy[j]
+ })
+ for i := range fieldsCopy {
+ if i > 0 && fieldsCopy[i-1] == fieldsCopy[i] {
+ Failf("duplicate field %q for type %s", fieldsCopy[i], name)
+ }
+ }
+}
+
+// typeEntry is an entry in the typeDatabase.
+type typeEntry struct {
+ ID typeID
+ wire.Type
+}
+
+// reconciledTypeEntry is a reconciled entry in the typeDatabase.
+type reconciledTypeEntry struct {
+ wire.Type
+ LocalType reflect.Type
+ FieldOrder []int
+}
+
+// typeEncodeDatabase is an internal TypeInfo database for encoding.
+type typeEncodeDatabase struct {
+ // byType maps by type to the typeEntry.
+ byType map[reflect.Type]*typeEntry
+
+ // lastID is the last used ID.
+ lastID typeID
+}
+
+// makeTypeEncodeDatabase makes a typeDatabase.
+func makeTypeEncodeDatabase() typeEncodeDatabase {
+ return typeEncodeDatabase{
+ byType: make(map[reflect.Type]*typeEntry),
+ }
+}
+
+// typeDecodeDatabase is an internal TypeInfo database for decoding.
+type typeDecodeDatabase struct {
+ // byID maps by ID to type.
+ byID []*reconciledTypeEntry
+
+ // pending are entries that are pending validation by Lookup. These
+ // will be reconciled with actual objects. Note that these will also be
+ // used to lookup types by name, since they may not be reconciled and
+ // there's little value to deleting from this map.
+ pending []*wire.Type
+}
+
+// makeTypeDecodeDatabase makes a typeDatabase.
+func makeTypeDecodeDatabase() typeDecodeDatabase {
+ return typeDecodeDatabase{}
+}
+
+// lookupNameFields extracts the name and fields from an object.
+func lookupNameFields(typ reflect.Type) (string, []string, bool) {
+ v := reflect.Zero(reflect.PtrTo(typ)).Interface()
+ t, ok := v.(Type)
+ if !ok {
+ // Is this a primitive?
+ if typ.Kind() == reflect.Interface {
+ return interfaceType, nil, true
+ }
+ name := typ.Name()
+ if _, ok := primitiveTypeDatabase[name]; !ok {
+ // This is not a known type, and not a primitive. The
+ // encoder may proceed for anonymous empty structs, or
+ // it may deference the type pointer and try again.
+ return "", nil, false
+ }
+ return name, nil, true
+ }
+ // Extract the name from the object.
+ name := t.StateTypeName()
+ fields := t.StateFields()
+ assertValidType(name, fields)
+ return name, fields, true
+}
+
+// Lookup looks up or registers the given object.
+//
+// The bool indicates whether this is an existing entry: false means the entry
+// did not exist, and true means the entry did exist. If this bool is false and
+// the returned typeEntry are nil, then the obj did not implement the Type
+// interface.
+func (tdb *typeEncodeDatabase) Lookup(typ reflect.Type) (*typeEntry, bool) {
+ te, ok := tdb.byType[typ]
+ if !ok {
+ // Lookup the type information.
+ name, fields, ok := lookupNameFields(typ)
+ if !ok {
+ // Empty structs may still be encoded, so let the
+ // caller decide what to do from here.
+ return nil, false
+ }
+
+ // Register the new type.
+ tdb.lastID++
+ te = &typeEntry{
+ ID: tdb.lastID,
+ Type: wire.Type{
+ Name: name,
+ Fields: fields,
+ },
+ }
+
+ // All done.
+ tdb.byType[typ] = te
+ return te, false
+ }
+ return te, true
+}
+
+// Register adds a typeID entry.
+func (tbd *typeDecodeDatabase) Register(typ *wire.Type) {
+ assertValidType(typ.Name, typ.Fields)
+ tbd.pending = append(tbd.pending, typ)
+}
+
+// LookupName looks up the type name by ID.
+func (tbd *typeDecodeDatabase) LookupName(id typeID) string {
+ if len(tbd.pending) < int(id) {
+ // This is likely an encoder error?
+ Failf("type ID %d not available", id)
+ }
+ return tbd.pending[id-1].Name
+}
+
+// LookupType looks up the type by ID.
+func (tbd *typeDecodeDatabase) LookupType(id typeID) reflect.Type {
+ name := tbd.LookupName(id)
+ typ, ok := globalTypeDatabase[name]
+ if !ok {
+ // If not available, see if it's primitive.
+ typ, ok = primitiveTypeDatabase[name]
+ if !ok && name == interfaceType {
+ // Matches the built-in interface type.
+ var i interface{}
+ return reflect.TypeOf(&i).Elem()
+ }
+ if !ok {
+ // The type is perhaps not registered?
+ Failf("type name %q is not available", name)
+ }
+ return typ // Primitive type.
+ }
+ return typ // Registered type.
+}
+
+// singleFieldOrder defines the field order for a single field.
+var singleFieldOrder = []int{0}
+
+// Lookup looks up or registers the given object.
+//
+// First, the typeID is searched to see if this has already been appropriately
+// reconciled. If no, then a reconcilation will take place that may result in a
+// field ordering. If a nil reconciledTypeEntry is returned from this method,
+// then the object does not support the Type interface.
+//
+// This method never returns nil.
+func (tbd *typeDecodeDatabase) Lookup(id typeID, typ reflect.Type) *reconciledTypeEntry {
+ if len(tbd.byID) > int(id) && tbd.byID[id-1] != nil {
+ // Already reconciled.
+ return tbd.byID[id-1]
+ }
+ // The ID has not been reconciled yet. That's fine. We need to make
+ // sure it aligns with the current provided object.
+ if len(tbd.pending) < int(id) {
+ // This id was never registered. Probably an encoder error?
+ Failf("typeDatabase does not contain id %d", id)
+ }
+ // Extract the pending info.
+ pending := tbd.pending[id-1]
+ // Grow the byID list.
+ if len(tbd.byID) < int(id) {
+ tbd.byID = append(tbd.byID, make([]*reconciledTypeEntry, int(id)-len(tbd.byID))...)
+ }
+ // Reconcile the type.
+ name, fields, ok := lookupNameFields(typ)
+ if !ok {
+ // Empty structs are decoded only when the type is nil. Since
+ // this isn't the case, we fail here.
+ Failf("unsupported type %q during decode; can't reconcile", pending.Name)
+ }
+ if name != pending.Name {
+ // Are these the same type? Print a helpful message as this may
+ // actually happen in practice if types change.
+ Failf("typeDatabase contains conflicting definitions for id %d: %s->%v (current) and %s->%v (existing)",
+ id, name, fields, pending.Name, pending.Fields)
+ }
+ rte := &reconciledTypeEntry{
+ Type: wire.Type{
+ Name: name,
+ Fields: fields,
+ },
+ LocalType: typ,
+ }
+ // If there are zero or one fields, then we skip allocating the field
+ // slice. There is special handling for decoding in this case. If the
+ // field name does not match, it will be caught in the general purpose
+ // code below.
+ if len(fields) != len(pending.Fields) {
+ Failf("type %q contains different fields: %v (decode) and %v (encode)",
+ name, fields, pending.Fields)
+ }
+ if len(fields) == 0 {
+ tbd.byID[id-1] = rte // Save.
+ return rte
+ }
+ if len(fields) == 1 && fields[0] == pending.Fields[0] {
+ tbd.byID[id-1] = rte // Save.
+ rte.FieldOrder = singleFieldOrder
+ return rte
+ }
+ // For each field in the current object's information, match it to a
+ // field in the destination object. We know from the assertion above
+ // and the insertion on insertion to pending that neither field
+ // contains any duplicates.
+ fieldOrder := make([]int, len(fields))
+ for i, name := range fields {
+ fieldOrder[i] = -1 // Sentinel.
+ // Is it an exact match?
+ if pending.Fields[i] == name {
+ fieldOrder[i] = i
+ continue
+ }
+ // Find the matching field.
+ for j, otherName := range pending.Fields {
+ if name == otherName {
+ fieldOrder[i] = j
+ break
+ }
+ }
+ if fieldOrder[i] == -1 {
+ // The type name matches but we are lacking some common fields.
+ Failf("type %q has mismatched fields: %v (decode) and %v (encode)",
+ name, fields, pending.Fields)
+ }
+ }
+ // The type has been reeconciled.
+ rte.FieldOrder = fieldOrder
+ tbd.byID[id-1] = rte
+ return rte
+}
+
+// interfaceType defines all interfaces.
+const interfaceType = "interface"
+
+// primitiveTypeDatabase is a set of fixed types.
+var primitiveTypeDatabase = func() map[string]reflect.Type {
+ r := make(map[string]reflect.Type)
+ for _, t := range []reflect.Type{
+ reflect.TypeOf(false),
+ reflect.TypeOf(int(0)),
+ reflect.TypeOf(int8(0)),
+ reflect.TypeOf(int16(0)),
+ reflect.TypeOf(int32(0)),
+ reflect.TypeOf(int64(0)),
+ reflect.TypeOf(uint(0)),
+ reflect.TypeOf(uintptr(0)),
+ reflect.TypeOf(uint8(0)),
+ reflect.TypeOf(uint16(0)),
+ reflect.TypeOf(uint32(0)),
+ reflect.TypeOf(uint64(0)),
+ reflect.TypeOf(""),
+ reflect.TypeOf(float32(0.0)),
+ reflect.TypeOf(float64(0.0)),
+ reflect.TypeOf(complex64(0.0)),
+ reflect.TypeOf(complex128(0.0)),
+ } {
+ r[t.Name()] = t
+ }
+ return r
+}()
+
+// globalTypeDatabase is used for dispatching interfaces on decode.
+var globalTypeDatabase = map[string]reflect.Type{}
+
+// Register registers a type.
+//
+// This must be called on init and only done once.
+func Register(t Type) {
+ name := t.StateTypeName()
+ fields := t.StateFields()
+ assertValidType(name, fields)
+ // Register must always be called on pointers.
+ typ := reflect.TypeOf(t)
+ if typ.Kind() != reflect.Ptr {
+ Failf("Register must be called on pointers")
+ }
+ typ = typ.Elem()
+ if typ.Kind() == reflect.Struct {
+ // All registered structs must implement SaverLoader. We allow
+ // the registration is non-struct types with just the Type
+ // interface, but we need to call StateSave/StateLoad methods
+ // on aggregate types.
+ if _, ok := t.(SaverLoader); !ok {
+ Failf("struct %T does not implement SaverLoader", t)
+ }
+ } else {
+ // Non-structs must not have any fields. We don't support
+ // calling StateSave/StateLoad methods on any non-struct types.
+ // If custom behavior is required, these types should be
+ // wrapped in a structure of some kind.
+ if len(fields) != 0 {
+ Failf("non-struct %T has non-zero fields %v", t, fields)
+ }
+ // We don't allow non-structs to implement StateSave/StateLoad
+ // methods, because they won't be called and it's confusing.
+ if _, ok := t.(SaverLoader); ok {
+ Failf("non-struct %T implements SaverLoader", t)
+ }
+ }
+ if _, ok := primitiveTypeDatabase[name]; ok {
+ Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t)
+ }
+ if _, ok := globalTypeDatabase[name]; ok {
+ Failf("conflicting globalTypeDatabase entries for %T: name conflict", t)
+ }
+ if name == interfaceType {
+ Failf("conflicting name for %T: matches interfaceType", t)
+ }
+ globalTypeDatabase[name] = typ
+}
diff --git a/pkg/state/wire/BUILD b/pkg/state/wire/BUILD
new file mode 100644
index 000000000..311b93dcb
--- /dev/null
+++ b/pkg/state/wire/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "wire",
+ srcs = ["wire.go"],
+ marshal = False,
+ stateify = False,
+ visibility = ["//:sandbox"],
+ deps = ["//pkg/gohacks"],
+)
diff --git a/pkg/state/wire/wire.go b/pkg/state/wire/wire.go
new file mode 100644
index 000000000..93dee6740
--- /dev/null
+++ b/pkg/state/wire/wire.go
@@ -0,0 +1,970 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package wire contains a few basic types that can be composed to serialize
+// graph information for the state package. This package defines the wire
+// protocol.
+//
+// Note that these types are careful about how they implement the relevant
+// interfaces (either value receiver or pointer receiver), so that native-sized
+// types, such as integers and simple pointers, can fit inside the interface
+// object.
+//
+// This package also uses panic as control flow, so called should be careful to
+// wrap calls in appropriate handlers.
+//
+// Testing for this package is driven by the state test package.
+package wire
+
+import (
+ "fmt"
+ "io"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
+)
+
+// Reader is the required reader interface.
+type Reader interface {
+ io.Reader
+ ReadByte() (byte, error)
+}
+
+// Writer is the required writer interface.
+type Writer interface {
+ io.Writer
+ WriteByte(byte) error
+}
+
+// readFull is a utility. The equivalent is not needed for Write, but the API
+// contract dictates that it must always complete all bytes given or return an
+// error.
+func readFull(r io.Reader, p []byte) {
+ for done := 0; done < len(p); {
+ n, err := r.Read(p[done:])
+ done += n
+ if n == 0 && err != nil {
+ panic(err)
+ }
+ }
+}
+
+// Object is a generic object.
+type Object interface {
+ // save saves the given object.
+ //
+ // Panic is used for error control flow.
+ save(Writer)
+
+ // load loads a new object of the given type.
+ //
+ // Panic is used for error control flow.
+ load(Reader) Object
+}
+
+// Bool is a boolean.
+type Bool bool
+
+// loadBool loads an object of type Bool.
+func loadBool(r Reader) Bool {
+ b := loadUint(r)
+ return Bool(b == 1)
+}
+
+// save implements Object.save.
+func (b Bool) save(w Writer) {
+ var v Uint
+ if b {
+ v = 1
+ } else {
+ v = 0
+ }
+ v.save(w)
+}
+
+// load implements Object.load.
+func (Bool) load(r Reader) Object { return loadBool(r) }
+
+// Int is a signed integer.
+//
+// This uses varint encoding.
+type Int int64
+
+// loadInt loads an object of type Int.
+func loadInt(r Reader) Int {
+ u := loadUint(r)
+ x := Int(u >> 1)
+ if u&1 != 0 {
+ x = ^x
+ }
+ return x
+}
+
+// save implements Object.save.
+func (i Int) save(w Writer) {
+ u := Uint(i) << 1
+ if i < 0 {
+ u = ^u
+ }
+ u.save(w)
+}
+
+// load implements Object.load.
+func (Int) load(r Reader) Object { return loadInt(r) }
+
+// Uint is an unsigned integer.
+type Uint uint64
+
+// loadUint loads an object of type Uint.
+func loadUint(r Reader) Uint {
+ var (
+ u Uint
+ s uint
+ )
+ for i := 0; i <= 9; i++ {
+ b, err := r.ReadByte()
+ if err != nil {
+ panic(err)
+ }
+ if b < 0x80 {
+ if i == 9 && b > 1 {
+ panic("overflow")
+ }
+ u |= Uint(b) << s
+ return u
+ }
+ u |= Uint(b&0x7f) << s
+ s += 7
+ }
+ panic("unreachable")
+}
+
+// save implements Object.save.
+func (u Uint) save(w Writer) {
+ for u >= 0x80 {
+ if err := w.WriteByte(byte(u) | 0x80); err != nil {
+ panic(err)
+ }
+ u >>= 7
+ }
+ if err := w.WriteByte(byte(u)); err != nil {
+ panic(err)
+ }
+}
+
+// load implements Object.load.
+func (Uint) load(r Reader) Object { return loadUint(r) }
+
+// Float32 is a 32-bit floating point number.
+type Float32 float32
+
+// loadFloat32 loads an object of type Float32.
+func loadFloat32(r Reader) Float32 {
+ n := loadUint(r)
+ return Float32(math.Float32frombits(uint32(n)))
+}
+
+// save implements Object.save.
+func (f Float32) save(w Writer) {
+ n := Uint(math.Float32bits(float32(f)))
+ n.save(w)
+}
+
+// load implements Object.load.
+func (Float32) load(r Reader) Object { return loadFloat32(r) }
+
+// Float64 is a 64-bit floating point number.
+type Float64 float64
+
+// loadFloat64 loads an object of type Float64.
+func loadFloat64(r Reader) Float64 {
+ n := loadUint(r)
+ return Float64(math.Float64frombits(uint64(n)))
+}
+
+// save implements Object.save.
+func (f Float64) save(w Writer) {
+ n := Uint(math.Float64bits(float64(f)))
+ n.save(w)
+}
+
+// load implements Object.load.
+func (Float64) load(r Reader) Object { return loadFloat64(r) }
+
+// Complex64 is a 64-bit complex number.
+type Complex64 complex128
+
+// loadComplex64 loads an object of type Complex64.
+func loadComplex64(r Reader) Complex64 {
+ re := loadFloat32(r)
+ im := loadFloat32(r)
+ return Complex64(complex(float32(re), float32(im)))
+}
+
+// save implements Object.save.
+func (c *Complex64) save(w Writer) {
+ re := Float32(real(*c))
+ im := Float32(imag(*c))
+ re.save(w)
+ im.save(w)
+}
+
+// load implements Object.load.
+func (*Complex64) load(r Reader) Object {
+ c := loadComplex64(r)
+ return &c
+}
+
+// Complex128 is a 128-bit complex number.
+type Complex128 complex128
+
+// loadComplex128 loads an object of type Complex128.
+func loadComplex128(r Reader) Complex128 {
+ re := loadFloat64(r)
+ im := loadFloat64(r)
+ return Complex128(complex(float64(re), float64(im)))
+}
+
+// save implements Object.save.
+func (c *Complex128) save(w Writer) {
+ re := Float64(real(*c))
+ im := Float64(imag(*c))
+ re.save(w)
+ im.save(w)
+}
+
+// load implements Object.load.
+func (*Complex128) load(r Reader) Object {
+ c := loadComplex128(r)
+ return &c
+}
+
+// String is a string.
+type String string
+
+// loadString loads an object of type String.
+func loadString(r Reader) String {
+ l := loadUint(r)
+ p := make([]byte, l)
+ readFull(r, p)
+ return String(gohacks.StringFromImmutableBytes(p))
+}
+
+// save implements Object.save.
+func (s *String) save(w Writer) {
+ l := Uint(len(*s))
+ l.save(w)
+ p := gohacks.ImmutableBytesFromString(string(*s))
+ _, err := w.Write(p) // Must write all bytes.
+ if err != nil {
+ panic(err)
+ }
+}
+
+// load implements Object.load.
+func (*String) load(r Reader) Object {
+ s := loadString(r)
+ return &s
+}
+
+// Dot is a kind of reference: one of Index and FieldName.
+type Dot interface {
+ isDot()
+}
+
+// Index is a reference resolution.
+type Index uint32
+
+func (Index) isDot() {}
+
+// FieldName is a reference resolution.
+type FieldName string
+
+func (*FieldName) isDot() {}
+
+// Ref is a reference to an object.
+type Ref struct {
+ // Root is the root object.
+ Root Uint
+
+ // Dots is the set of traversals required from the Root object above.
+ // Note that this will be stored in reverse order for efficiency.
+ Dots []Dot
+
+ // Type is the base type for the root object. This is non-nil iff Dots
+ // is non-zero length (that is, this is a complex reference). This is
+ // not *strictly* necessary, but can be used to simplify decoding.
+ Type TypeSpec
+}
+
+// loadRef loads an object of type Ref (abstract).
+func loadRef(r Reader) Ref {
+ ref := Ref{
+ Root: loadUint(r),
+ }
+ l := loadUint(r)
+ ref.Dots = make([]Dot, l)
+ for i := 0; i < int(l); i++ {
+ // Disambiguate between an Index (non-negative) and a field
+ // name (negative). This does some space and avoids a dedicate
+ // loadDot function. See Ref.save for the other side.
+ d := loadInt(r)
+ if d >= 0 {
+ ref.Dots[i] = Index(d)
+ continue
+ }
+ p := make([]byte, -d)
+ readFull(r, p)
+ fieldName := FieldName(gohacks.StringFromImmutableBytes(p))
+ ref.Dots[i] = &fieldName
+ }
+ if l != 0 {
+ // Only if dots is non-zero.
+ ref.Type = loadTypeSpec(r)
+ }
+ return ref
+}
+
+// save implements Object.save.
+func (r *Ref) save(w Writer) {
+ r.Root.save(w)
+ l := Uint(len(r.Dots))
+ l.save(w)
+ for _, d := range r.Dots {
+ // See LoadRef. We use non-negative numbers to encode Index
+ // objects and negative numbers to encode field lengths.
+ switch x := d.(type) {
+ case Index:
+ i := Int(x)
+ i.save(w)
+ case *FieldName:
+ d := Int(-len(*x))
+ d.save(w)
+ p := gohacks.ImmutableBytesFromString(string(*x))
+ if _, err := w.Write(p); err != nil {
+ panic(err)
+ }
+ default:
+ panic("unknown dot implementation")
+ }
+ }
+ if l != 0 {
+ // See above.
+ saveTypeSpec(w, r.Type)
+ }
+}
+
+// load implements Object.load.
+func (*Ref) load(r Reader) Object {
+ ref := loadRef(r)
+ return &ref
+}
+
+// Nil is a primitive zero value of any type.
+type Nil struct{}
+
+// loadNil loads an object of type Nil.
+func loadNil(r Reader) Nil {
+ return Nil{}
+}
+
+// save implements Object.save.
+func (Nil) save(w Writer) {}
+
+// load implements Object.load.
+func (Nil) load(r Reader) Object { return loadNil(r) }
+
+// Slice is a slice value.
+type Slice struct {
+ Length Uint
+ Capacity Uint
+ Ref Ref
+}
+
+// loadSlice loads an object of type Slice.
+func loadSlice(r Reader) Slice {
+ return Slice{
+ Length: loadUint(r),
+ Capacity: loadUint(r),
+ Ref: loadRef(r),
+ }
+}
+
+// save implements Object.save.
+func (s *Slice) save(w Writer) {
+ s.Length.save(w)
+ s.Capacity.save(w)
+ s.Ref.save(w)
+}
+
+// load implements Object.load.
+func (*Slice) load(r Reader) Object {
+ s := loadSlice(r)
+ return &s
+}
+
+// Array is an array value.
+type Array struct {
+ Contents []Object
+}
+
+// loadArray loads an object of type Array.
+func loadArray(r Reader) Array {
+ l := loadUint(r)
+ if l == 0 {
+ // Note that there isn't a single object available to encode
+ // the type of, so we need this additional branch.
+ return Array{}
+ }
+ // All the objects here have the same type, so use dynamic dispatch
+ // only once. All other objects will automatically take the same type
+ // as the first object.
+ contents := make([]Object, l)
+ v := Load(r)
+ contents[0] = v
+ for i := 1; i < int(l); i++ {
+ contents[i] = v.load(r)
+ }
+ return Array{
+ Contents: contents,
+ }
+}
+
+// save implements Object.save.
+func (a *Array) save(w Writer) {
+ l := Uint(len(a.Contents))
+ l.save(w)
+ if l == 0 {
+ // See LoadArray.
+ return
+ }
+ // See above.
+ Save(w, a.Contents[0])
+ for i := 1; i < int(l); i++ {
+ a.Contents[i].save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Array) load(r Reader) Object {
+ a := loadArray(r)
+ return &a
+}
+
+// Map is a map value.
+type Map struct {
+ Keys []Object
+ Values []Object
+}
+
+// loadMap loads an object of type Map.
+func loadMap(r Reader) Map {
+ l := loadUint(r)
+ if l == 0 {
+ // See LoadArray.
+ return Map{}
+ }
+ // See type dispatch notes in Array.
+ keys := make([]Object, l)
+ values := make([]Object, l)
+ k := Load(r)
+ v := Load(r)
+ keys[0] = k
+ values[0] = v
+ for i := 1; i < int(l); i++ {
+ keys[i] = k.load(r)
+ values[i] = v.load(r)
+ }
+ return Map{
+ Keys: keys,
+ Values: values,
+ }
+}
+
+// save implements Object.save.
+func (m *Map) save(w Writer) {
+ l := Uint(len(m.Keys))
+ if int(l) != len(m.Values) {
+ panic(fmt.Sprintf("mismatched keys (%d) Aand values (%d)", len(m.Keys), len(m.Values)))
+ }
+ l.save(w)
+ if l == 0 {
+ // See LoadArray.
+ return
+ }
+ // See above.
+ Save(w, m.Keys[0])
+ Save(w, m.Values[0])
+ for i := 1; i < int(l); i++ {
+ m.Keys[i].save(w)
+ m.Values[i].save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Map) load(r Reader) Object {
+ m := loadMap(r)
+ return &m
+}
+
+// TypeSpec is a type dereference.
+type TypeSpec interface {
+ isTypeSpec()
+}
+
+// TypeID is a concrete type ID.
+type TypeID Uint
+
+func (TypeID) isTypeSpec() {}
+
+// TypeSpecPointer is a pointer type.
+type TypeSpecPointer struct {
+ Type TypeSpec
+}
+
+func (*TypeSpecPointer) isTypeSpec() {}
+
+// TypeSpecArray is an array type.
+type TypeSpecArray struct {
+ Count Uint
+ Type TypeSpec
+}
+
+func (*TypeSpecArray) isTypeSpec() {}
+
+// TypeSpecSlice is a slice type.
+type TypeSpecSlice struct {
+ Type TypeSpec
+}
+
+func (*TypeSpecSlice) isTypeSpec() {}
+
+// TypeSpecMap is a map type.
+type TypeSpecMap struct {
+ Key TypeSpec
+ Value TypeSpec
+}
+
+func (*TypeSpecMap) isTypeSpec() {}
+
+// TypeSpecNil is an empty type.
+type TypeSpecNil struct{}
+
+func (TypeSpecNil) isTypeSpec() {}
+
+// TypeSpec types.
+//
+// These use a distinct encoding on the wire, as they are used only in the
+// interface object. They are decoded through the dedicated loadTypeSpec and
+// saveTypeSpec functions.
+const (
+ typeSpecTypeID Uint = iota
+ typeSpecPointer
+ typeSpecArray
+ typeSpecSlice
+ typeSpecMap
+ typeSpecNil
+)
+
+// loadTypeSpec loads TypeSpec values.
+func loadTypeSpec(r Reader) TypeSpec {
+ switch hdr := loadUint(r); hdr {
+ case typeSpecTypeID:
+ return TypeID(loadUint(r))
+ case typeSpecPointer:
+ return &TypeSpecPointer{
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecArray:
+ return &TypeSpecArray{
+ Count: loadUint(r),
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecSlice:
+ return &TypeSpecSlice{
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecMap:
+ return &TypeSpecMap{
+ Key: loadTypeSpec(r),
+ Value: loadTypeSpec(r),
+ }
+ case typeSpecNil:
+ return TypeSpecNil{}
+ default:
+ // This is not a valid stream?
+ panic(fmt.Errorf("unknown header: %d", hdr))
+ }
+}
+
+// saveTypeSpec saves TypeSpec values.
+func saveTypeSpec(w Writer, t TypeSpec) {
+ switch x := t.(type) {
+ case TypeID:
+ typeSpecTypeID.save(w)
+ Uint(x).save(w)
+ case *TypeSpecPointer:
+ typeSpecPointer.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecArray:
+ typeSpecArray.save(w)
+ x.Count.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecSlice:
+ typeSpecSlice.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecMap:
+ typeSpecMap.save(w)
+ saveTypeSpec(w, x.Key)
+ saveTypeSpec(w, x.Value)
+ case TypeSpecNil:
+ typeSpecNil.save(w)
+ default:
+ // This should not happen?
+ panic(fmt.Errorf("unknown type %T", t))
+ }
+}
+
+// Interface is an interface value.
+type Interface struct {
+ Type TypeSpec
+ Value Object
+}
+
+// loadInterface loads an object of type Interface.
+func loadInterface(r Reader) Interface {
+ return Interface{
+ Type: loadTypeSpec(r),
+ Value: Load(r),
+ }
+}
+
+// save implements Object.save.
+func (i *Interface) save(w Writer) {
+ saveTypeSpec(w, i.Type)
+ Save(w, i.Value)
+}
+
+// load implements Object.load.
+func (*Interface) load(r Reader) Object {
+ i := loadInterface(r)
+ return &i
+}
+
+// Type is type information.
+type Type struct {
+ Name string
+ Fields []string
+}
+
+// loadType loads an object of type Type.
+func loadType(r Reader) Type {
+ name := string(loadString(r))
+ l := loadUint(r)
+ fields := make([]string, l)
+ for i := 0; i < int(l); i++ {
+ fields[i] = string(loadString(r))
+ }
+ return Type{
+ Name: name,
+ Fields: fields,
+ }
+}
+
+// save implements Object.save.
+func (t *Type) save(w Writer) {
+ s := String(t.Name)
+ s.save(w)
+ l := Uint(len(t.Fields))
+ l.save(w)
+ for i := 0; i < int(l); i++ {
+ s := String(t.Fields[i])
+ s.save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Type) load(r Reader) Object {
+ t := loadType(r)
+ return &t
+}
+
+// multipleObjects is a special type for serializing multiple objects.
+type multipleObjects []Object
+
+// loadMultipleObjects loads a series of objects.
+func loadMultipleObjects(r Reader) multipleObjects {
+ l := loadUint(r)
+ m := make(multipleObjects, l)
+ for i := 0; i < int(l); i++ {
+ m[i] = Load(r)
+ }
+ return m
+}
+
+// save implements Object.save.
+func (m *multipleObjects) save(w Writer) {
+ l := Uint(len(*m))
+ l.save(w)
+ for i := 0; i < int(l); i++ {
+ Save(w, (*m)[i])
+ }
+}
+
+// load implements Object.load.
+func (*multipleObjects) load(r Reader) Object {
+ m := loadMultipleObjects(r)
+ return &m
+}
+
+// noObjects represents no objects.
+type noObjects struct{}
+
+// loadNoObjects loads a sentinel.
+func loadNoObjects(r Reader) noObjects { return noObjects{} }
+
+// save implements Object.save.
+func (noObjects) save(w Writer) {}
+
+// load implements Object.load.
+func (noObjects) load(r Reader) Object { return loadNoObjects(r) }
+
+// Struct is a basic composite value.
+type Struct struct {
+ TypeID TypeID
+ fields Object // Optionally noObjects or *multipleObjects.
+}
+
+// Field returns a pointer to the given field slot.
+//
+// This must be called after Alloc.
+func (s *Struct) Field(i int) *Object {
+ if fields, ok := s.fields.(*multipleObjects); ok {
+ return &((*fields)[i])
+ }
+ if _, ok := s.fields.(noObjects); ok {
+ // Alloc may be optionally called; can't call twice.
+ panic("Field called inappropriately, wrong Alloc?")
+ }
+ return &s.fields
+}
+
+// Alloc allocates the given number of fields.
+//
+// This must be called before Add and Save.
+//
+// Precondition: slots must be positive.
+func (s *Struct) Alloc(slots int) {
+ switch {
+ case slots == 0:
+ s.fields = noObjects{}
+ case slots == 1:
+ // Leave it alone.
+ case slots > 1:
+ fields := make(multipleObjects, slots)
+ s.fields = &fields
+ default:
+ // Violates precondition.
+ panic(fmt.Sprintf("Alloc called with negative slots %d?", slots))
+ }
+}
+
+// Fields returns the number of fields.
+func (s *Struct) Fields() int {
+ switch x := s.fields.(type) {
+ case *multipleObjects:
+ return len(*x)
+ case noObjects:
+ return 0
+ default:
+ return 1
+ }
+}
+
+// loadStruct loads an object of type Struct.
+func loadStruct(r Reader) Struct {
+ return Struct{
+ TypeID: TypeID(loadUint(r)),
+ fields: Load(r),
+ }
+}
+
+// save implements Object.save.
+//
+// Precondition: Alloc must have been called, and the fields all filled in
+// appropriately. See Alloc and Add for more details.
+func (s *Struct) save(w Writer) {
+ Uint(s.TypeID).save(w)
+ Save(w, s.fields)
+}
+
+// load implements Object.load.
+func (*Struct) load(r Reader) Object {
+ s := loadStruct(r)
+ return &s
+}
+
+// Object types.
+//
+// N.B. Be careful about changing the order or introducing new elements in the
+// middle here. This is part of the wire format and shouldn't change.
+const (
+ typeBool Uint = iota
+ typeInt
+ typeUint
+ typeFloat32
+ typeFloat64
+ typeNil
+ typeRef
+ typeString
+ typeSlice
+ typeArray
+ typeMap
+ typeStruct
+ typeNoObjects
+ typeMultipleObjects
+ typeInterface
+ typeComplex64
+ typeComplex128
+ typeType
+)
+
+// Save saves the given object.
+//
+// +checkescape all
+//
+// N.B. This function will panic on error.
+func Save(w Writer, obj Object) {
+ switch x := obj.(type) {
+ case Bool:
+ typeBool.save(w)
+ x.save(w)
+ case Int:
+ typeInt.save(w)
+ x.save(w)
+ case Uint:
+ typeUint.save(w)
+ x.save(w)
+ case Float32:
+ typeFloat32.save(w)
+ x.save(w)
+ case Float64:
+ typeFloat64.save(w)
+ x.save(w)
+ case Nil:
+ typeNil.save(w)
+ x.save(w)
+ case *Ref:
+ typeRef.save(w)
+ x.save(w)
+ case *String:
+ typeString.save(w)
+ x.save(w)
+ case *Slice:
+ typeSlice.save(w)
+ x.save(w)
+ case *Array:
+ typeArray.save(w)
+ x.save(w)
+ case *Map:
+ typeMap.save(w)
+ x.save(w)
+ case *Struct:
+ typeStruct.save(w)
+ x.save(w)
+ case noObjects:
+ typeNoObjects.save(w)
+ x.save(w)
+ case *multipleObjects:
+ typeMultipleObjects.save(w)
+ x.save(w)
+ case *Interface:
+ typeInterface.save(w)
+ x.save(w)
+ case *Type:
+ typeType.save(w)
+ x.save(w)
+ case *Complex64:
+ typeComplex64.save(w)
+ x.save(w)
+ case *Complex128:
+ typeComplex128.save(w)
+ x.save(w)
+ default:
+ panic(fmt.Errorf("unknown type: %#v", obj))
+ }
+}
+
+// Load loads a new object.
+//
+// +checkescape all
+//
+// N.B. This function will panic on error.
+func Load(r Reader) Object {
+ switch hdr := loadUint(r); hdr {
+ case typeBool:
+ return loadBool(r)
+ case typeInt:
+ return loadInt(r)
+ case typeUint:
+ return loadUint(r)
+ case typeFloat32:
+ return loadFloat32(r)
+ case typeFloat64:
+ return loadFloat64(r)
+ case typeNil:
+ return loadNil(r)
+ case typeRef:
+ return ((*Ref)(nil)).load(r) // Escapes.
+ case typeString:
+ return ((*String)(nil)).load(r) // Escapes.
+ case typeSlice:
+ return ((*Slice)(nil)).load(r) // Escapes.
+ case typeArray:
+ return ((*Array)(nil)).load(r) // Escapes.
+ case typeMap:
+ return ((*Map)(nil)).load(r) // Escapes.
+ case typeStruct:
+ return ((*Struct)(nil)).load(r) // Escapes.
+ case typeNoObjects: // Special for struct.
+ return loadNoObjects(r)
+ case typeMultipleObjects: // Special for struct.
+ return ((*multipleObjects)(nil)).load(r) // Escapes.
+ case typeInterface:
+ return ((*Interface)(nil)).load(r) // Escapes.
+ case typeComplex64:
+ return ((*Complex64)(nil)).load(r) // Escapes.
+ case typeComplex128:
+ return ((*Complex128)(nil)).load(r) // Escapes.
+ case typeType:
+ return ((*Type)(nil)).load(r) // Escapes.
+ default:
+ // This is not a valid stream?
+ panic(fmt.Errorf("unknown header: %d", hdr))
+ }
+}
+
+// LoadUint loads a single unsigned integer.
+//
+// N.B. This function will panic on error.
+func LoadUint(r Reader) uint64 {
+ return uint64(loadUint(r))
+}
+
+// SaveUint saves a single unsigned integer.
+//
+// N.B. This function will panic on error.
+func SaveUint(w Writer, v uint64) {
+ Uint(v).save(w)
+}
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
index af3538ef0..dae9b3b3e 100644
--- a/runsc/cmd/BUILD
+++ b/runsc/cmd/BUILD
@@ -45,7 +45,7 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/platform",
- "//pkg/state",
+ "//pkg/state/pretty",
"//pkg/state/statefile",
"//pkg/sync",
"//pkg/unet",
diff --git a/runsc/cmd/statefile.go b/runsc/cmd/statefile.go
index e6f1907da..daed9e728 100644
--- a/runsc/cmd/statefile.go
+++ b/runsc/cmd/statefile.go
@@ -20,7 +20,7 @@ import (
"os"
"github.com/google/subcommands"
- "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/pretty"
"gvisor.dev/gvisor/pkg/state/statefile"
"gvisor.dev/gvisor/runsc/flag"
)
@@ -105,8 +105,14 @@ func (s *Statefile) Execute(_ context.Context, f *flag.FlagSet, args ...interfac
if err != nil {
Fatalf("error parsing statefile: %v", err)
}
- if err := state.PrettyPrint(output, rc, s.html); err != nil {
- Fatalf("error printing state: %v", err)
+ if s.html {
+ if err := pretty.PrintHTML(output, rc); err != nil {
+ Fatalf("error printing state: %v", err)
+ }
+ } else {
+ if err := pretty.PrintText(output, rc); err != nil {
+ Fatalf("error printing state: %v", err)
+ }
}
return subcommands.ExitSuccess
}
diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go
index 571e9a6e6..f8def4823 100644
--- a/tools/checkescape/checkescape.go
+++ b/tools/checkescape/checkescape.go
@@ -88,7 +88,7 @@ const (
testMagic = "// +mustescape:"
// exempt is the exemption annotation.
- exempt = "// escapes:"
+ exempt = "// escapes"
)
// escapingBuiltins are builtins known to escape.
@@ -546,7 +546,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
for _, cg := range f.Comments {
for _, c := range cg.List {
p := pass.Fset.Position(c.Slash)
- if strings.HasPrefix(c.Text, exempt) {
+ if strings.HasPrefix(strings.ToLower(c.Text), exempt) {
exemptions[LinePosition{
Filename: filepath.Base(p.Filename),
Line: p.Line,
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
index 309ee9c21..4f6ed208a 100644
--- a/tools/go_stateify/main.go
+++ b/tools/go_stateify/main.go
@@ -103,7 +103,7 @@ type scanFunctions struct {
// skipped if nil.
//
// Fields tagged nosave are skipped.
-func scanFields(ss *ast.StructType, fn scanFunctions) {
+func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) {
if ss.Fields.List == nil {
// No fields.
return
@@ -127,7 +127,16 @@ func scanFields(ss *ast.StructType, fn scanFunctions) {
continue
}
- switch tag := extractStateTag(field.Tag); tag {
+ // Is this a anonymous struct? If yes, then continue the
+ // recursion with the given prefix. We don't pay attention to
+ // any tags on the top-level struct field.
+ tag := extractStateTag(field.Tag)
+ if anon, ok := field.Type.(*ast.StructType); ok && tag == "" {
+ scanFields(anon, name+".", fn)
+ continue
+ }
+
+ switch tag {
case "zerovalue":
if fn.zerovalue != nil {
fn.zerovalue(name)
@@ -201,28 +210,12 @@ func main() {
// initCalls is dumped at the end.
var initCalls []string
- // Declare our emission closures.
+ // Common closures.
emitRegister := func(name string) {
- initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name))
+ initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name))
}
emitZeroCheck := func(name string) {
- fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { m.Failf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, name, name)
- }
- emitLoadValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName)
- }
- emitLoad := func(name string) {
- fmt.Fprintf(outputFile, " m.Load(\"%s\", &x.%s)\n", name, name)
- }
- emitLoadWait := func(name string) {
- fmt.Fprintf(outputFile, " m.LoadWait(\"%s\", &x.%s)\n", name, name)
- }
- emitSaveValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name))
- fmt.Fprintf(outputFile, " m.SaveValue(\"%s\", %s)\n", name, name)
- }
- emitSave := func(name string) {
- fmt.Fprintf(outputFile, " m.Save(\"%s\", &x.%s)\n", name, name)
+ fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { %sFailf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, statePrefix, name, name)
}
// Automated warning.
@@ -329,87 +322,140 @@ func main() {
continue
}
- // Only generate code for types marked
- // "// +stateify savable" in one of the proceeding
- // comment lines.
+ // Only generate code for types marked "// +stateify
+ // savable" in one of the proceeding comment lines. If
+ // the line is marked "// +stateify type" then only
+ // generate type information and register the type.
if d.Doc == nil {
continue
}
- savable := false
+ var (
+ generateTypeInfo = false
+ generateSaverLoader = false
+ )
for _, l := range d.Doc.List {
if l.Text == "// +stateify savable" {
- savable = true
+ generateTypeInfo = true
+ generateSaverLoader = true
break
}
+ if l.Text == "// +stateify type" {
+ generateTypeInfo = true
+ }
}
- if !savable {
+ if !generateTypeInfo && !generateSaverLoader {
continue
}
for _, gs := range d.Specs {
ts := gs.(*ast.TypeSpec)
- switch ts.Type.(type) {
- case *ast.InterfaceType, *ast.ChanType, *ast.FuncType, *ast.ParenExpr, *ast.StarExpr:
- // Don't register.
- break
+ switch x := ts.Type.(type) {
case *ast.StructType:
maybeEmitImports()
- ss := ts.Type.(*ast.StructType)
+ // Record the slot for each field.
+ fieldCount := 0
+ fields := make(map[string]int)
+ emitField := func(name string) {
+ fmt.Fprintf(outputFile, " \"%s\",\n", name)
+ fields[name] = fieldCount
+ fieldCount++
+ }
+ emitFieldValue := func(name string, _ string) {
+ emitField(name)
+ }
+ emitLoadValue := func(name, typName string) {
+ fmt.Fprintf(outputFile, " m.LoadValue(%d, new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", fields[name], typName, camelCased(name), typName)
+ }
+ emitLoad := func(name string) {
+ fmt.Fprintf(outputFile, " m.Load(%d, &x.%s)\n", fields[name], name)
+ }
+ emitLoadWait := func(name string) {
+ fmt.Fprintf(outputFile, " m.LoadWait(%d, &x.%s)\n", fields[name], name)
+ }
+ emitSaveValue := func(name, typName string) {
+ fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name))
+ fmt.Fprintf(outputFile, " m.SaveValue(%d, %s)\n", fields[name], name)
+ }
+ emitSave := func(name string) {
+ fmt.Fprintf(outputFile, " m.Save(%d, &x.%s)\n", fields[name], name)
+ }
+
+ // Generate the type name method.
+ fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
+ fmt.Fprintf(outputFile, "}\n\n")
+
+ // Generate the fields method.
+ fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return []string{\n")
+ scanFields(x, "", scanFunctions{
+ normal: emitField,
+ wait: emitField,
+ value: emitFieldValue,
+ })
+ fmt.Fprintf(outputFile, " }\n")
+ fmt.Fprintf(outputFile, "}\n\n")
- // Define beforeSave if a definition was not found. This
- // prevents the code from compiling if a custom beforeSave
- // was defined in a file not provided to this binary and
- // prevents inherited methods from being called multiple times
- // by overriding them.
- if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok {
- fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n", ts.Name.Name)
+ // Define beforeSave if a definition was not found. This prevents
+ // the code from compiling if a custom beforeSave was defined in a
+ // file not provided to this binary and prevents inherited methods
+ // from being called multiple times by overriding them.
+ if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n\n", ts.Name.Name)
}
// Generate the save method.
- fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " x.beforeSave()\n")
- scanFields(ss, scanFunctions{zerovalue: emitZeroCheck})
- scanFields(ss, scanFunctions{value: emitSaveValue})
- scanFields(ss, scanFunctions{normal: emitSave, wait: emitSave})
- fmt.Fprintf(outputFile, "}\n\n")
+ //
+ // N.B. For historical reasons, we perform the value saves first,
+ // and perform the value loads last. There should be no dependency
+ // on this specific behavior, but the ability to specify slots
+ // allows a manual implementation to be order-dependent.
+ if generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) StateSave(m %sSink) {\n", ts.Name.Name, statePrefix)
+ fmt.Fprintf(outputFile, " x.beforeSave()\n")
+ scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck})
+ scanFields(x, "", scanFunctions{value: emitSaveValue})
+ scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave})
+ fmt.Fprintf(outputFile, "}\n\n")
+ }
- // Define afterLoad if a definition was not found. We do this
- // for the same reason that we do it for beforeSave.
+ // Define afterLoad if a definition was not found. We do this for
+ // the same reason that we do it for beforeSave.
_, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}]
- if !hasAfterLoad {
- fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n", ts.Name.Name)
+ if !hasAfterLoad && generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n\n", ts.Name.Name)
}
// Generate the load method.
//
- // Note that the manual loads always follow the
- // automated loads.
- fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix)
- scanFields(ss, scanFunctions{normal: emitLoad, wait: emitLoadWait})
- scanFields(ss, scanFunctions{value: emitLoadValue})
- if hasAfterLoad {
- // The call to afterLoad is made conditionally, because when
- // AfterLoad is called, the object encodes a dependency on
- // referred objects (i.e. fields). This means that afterLoad
- // will not be called until the other afterLoads are called.
- fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n")
+ // N.B. See the comment above for the save method.
+ if generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) StateLoad(m %sSource) {\n", ts.Name.Name, statePrefix)
+ scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait})
+ scanFields(x, "", scanFunctions{value: emitLoadValue})
+ if hasAfterLoad {
+ // The call to afterLoad is made conditionally, because when
+ // AfterLoad is called, the object encodes a dependency on
+ // referred objects (i.e. fields). This means that afterLoad
+ // will not be called until the other afterLoads are called.
+ fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n")
+ }
+ fmt.Fprintf(outputFile, "}\n\n")
}
- fmt.Fprintf(outputFile, "}\n\n")
// Add to our registration.
emitRegister(ts.Name.Name)
+
case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType:
maybeEmitImports()
- _, val := resolveTypeName(ts.Name.Name, ts.Type)
-
- // Dispatch directly.
- fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " m.SaveValue(\"\", (%s)(*x))\n", val)
+ // Generate the info methods.
+ fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
fmt.Fprintf(outputFile, "}\n\n")
- fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " m.LoadValue(\"\", new(%s), func(y interface{}) { *x = (%s)(y.(%s)) })\n", val, ts.Name.Name, val)
+ fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return nil\n")
fmt.Fprintf(outputFile, "}\n\n")
// See above.