diff options
Diffstat (limited to 'pkg')
47 files changed, 5045 insertions, 2436 deletions
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go index 5f52cbe74..b094c5662 100644 --- a/pkg/compressio/compressio.go +++ b/pkg/compressio/compressio.go @@ -346,20 +346,22 @@ func (p *pool) schedule(c *chunk, callback func(*chunk) error) error { } } -// reader chunks reads and decompresses. -type reader struct { +// Reader is a compressed reader. +type Reader struct { pool // in is the source. in io.Reader } +var _ io.Reader = (*Reader)(nil) + // NewReader returns a new compressed reader. If key is non-nil, the data stream // is assumed to contain expected hash values, which will be compared against // hash values computed from the compressed bytes. See package comments for // details. -func NewReader(in io.Reader, key []byte) (io.Reader, error) { - r := &reader{ +func NewReader(in io.Reader, key []byte) (*Reader, error) { + r := &Reader{ in: in, } @@ -394,8 +396,19 @@ var errNewBuffer = errors.New("buffer ready") // ErrHashMismatch is returned if the hash does not match. var ErrHashMismatch = errors.New("hash mismatch") +// ReadByte implements wire.Reader.ReadByte. +func (r *Reader) ReadByte() (byte, error) { + var p [1]byte + n, err := r.Read(p[:]) + if n != 1 { + return p[0], err + } + // Suppress EOF. + return p[0], nil +} + // Read implements io.Reader.Read. -func (r *reader) Read(p []byte) (int, error) { +func (r *Reader) Read(p []byte) (int, error) { r.mu.Lock() defer r.mu.Unlock() @@ -551,8 +564,8 @@ func (r *reader) Read(p []byte) (int, error) { return done, nil } -// writer chunks and schedules writes. -type writer struct { +// Writer is a compressed writer. +type Writer struct { pool // out is the underlying writer. @@ -562,6 +575,8 @@ type writer struct { closed bool } +var _ io.Writer = (*Writer)(nil) + // NewWriter returns a new compressed writer. If key is non-nil, hash values are // generated and written out for compressed bytes. See package comments for // details. @@ -569,8 +584,8 @@ type writer struct { // The recommended chunkSize is on the order of 1M. Extra memory may be // buffered (in the form of read-ahead, or buffered writes), and is limited to // O(chunkSize * [1+GOMAXPROCS]). -func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.WriteCloser, error) { - w := &writer{ +func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, error) { + w := &Writer{ pool: pool{ chunkSize: chunkSize, buf: bufPool.Get().(*bytes.Buffer), @@ -597,7 +612,7 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.Write } // flush writes a single buffer. -func (w *writer) flush(c *chunk) error { +func (w *Writer) flush(c *chunk) error { // Prefix each chunk with a length; this allows the reader to safely // limit reads while buffering. l := uint32(c.compressed.Len()) @@ -624,8 +639,23 @@ func (w *writer) flush(c *chunk) error { return nil } +// WriteByte implements wire.Writer.WriteByte. +// +// Note that this implementation is necessary on the object itself, as an +// interface-based dispatch cannot tell whether the array backing the slice +// escapes, therefore the all bytes written will generate an escape. +func (w *Writer) WriteByte(b byte) error { + var p [1]byte + p[0] = b + n, err := w.Write(p[:]) + if n != 1 { + return err + } + return nil +} + // Write implements io.Writer.Write. -func (w *writer) Write(p []byte) (int, error) { +func (w *Writer) Write(p []byte) (int, error) { w.mu.Lock() defer w.mu.Unlock() @@ -710,7 +740,7 @@ func (w *writer) Write(p []byte) (int, error) { } // Close implements io.Closer.Close. -func (w *writer) Close() error { +func (w *Writer) Close() error { w.mu.Lock() defer w.mu.Unlock() diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD index 798a65eca..35683fe98 100644 --- a/pkg/gohacks/BUILD +++ b/pkg/gohacks/BUILD @@ -7,5 +7,6 @@ go_library( srcs = [ "gohacks_unsafe.go", ], + stateify = False, visibility = ["//:sandbox"], ) diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go index 0d07da3b1..f4a4c33d3 100644 --- a/pkg/ilist/list.go +++ b/pkg/ilist/list.go @@ -90,7 +90,7 @@ func (l *List) Back() Element { // // NOTE: This is an O(n) operation. func (l *List) Len() (count int) { - for e := l.Front(); e != nil; e = e.Next() { + for e := l.Front(); e != nil; e = (ElementMapper{}.linkerFor(e)).Next() { count++ } return count @@ -182,13 +182,13 @@ func (l *List) Remove(e Element) { if prev != nil { ElementMapper{}.linkerFor(prev).SetNext(next) - } else { + } else if l.head == e { l.head = next } if next != nil { ElementMapper{}.linkerFor(next).SetPrev(prev) - } else { + } else if l.tail == e { l.tail = prev } diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 1510a7c26..25fe1921b 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -200,6 +200,7 @@ go_library( "//pkg/sentry/vfs", "//pkg/state", "//pkg/state/statefile", + "//pkg/state/wire", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 554a42e05..2177b785a 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -34,7 +34,6 @@ package kernel import ( "errors" "fmt" - "io" "path/filepath" "sync/atomic" "time" @@ -73,6 +72,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/wire" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -417,7 +417,7 @@ func (k *Kernel) Init(args InitKernelArgs) error { // SaveTo saves the state of k to w. // // Preconditions: The kernel must be paused throughout the call to SaveTo. -func (k *Kernel) SaveTo(w io.Writer) error { +func (k *Kernel) SaveTo(w wire.Writer) error { saveStart := time.Now() ctx := k.SupervisorContext() @@ -473,18 +473,18 @@ func (k *Kernel) SaveTo(w io.Writer) error { // // N.B. This will also be saved along with the full kernel save below. cpuidStart := time.Now() - if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil { + if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil { return err } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) // Save the kernel state. kernelStart := time.Now() - var stats state.Stats - if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil { + stats, err := state.Save(k.SupervisorContext(), w, k) + if err != nil { return err } - log.Infof("Kernel save stats: %s", &stats) + log.Infof("Kernel save stats: %s", stats.String()) log.Infof("Kernel save took [%s].", time.Since(kernelStart)) // Save the memory file's state. @@ -629,7 +629,7 @@ func (ts *TaskSet) unregisterEpollWaiters() { } // LoadFrom returns a new Kernel loaded from args. -func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error { +func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error { loadStart := time.Now() initAppCores := k.applicationCores @@ -640,7 +640,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // don't need to explicitly install it in the Kernel. cpuidStart := time.Now() var features cpuid.FeatureSet - if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil { + if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil { return err } log.Infof("CPUID load took [%s].", time.Since(cpuidStart)) @@ -655,11 +655,11 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // Load the kernel state. kernelStart := time.Now() - var stats state.Stats - if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil { + stats, err := state.Load(k.SupervisorContext(), r, k) + if err != nil { return err } - log.Infof("Kernel load stats: %s", &stats) + log.Infof("Kernel load stats: %s", stats.String()) log.Infof("Kernel load took [%s].", time.Since(kernelStart)) // rootNetworkNamespace should be populated after loading the state file. diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index a9836ba71..e1fcb175f 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -92,6 +92,7 @@ go_library( "//pkg/sentry/platform", "//pkg/sentry/usage", "//pkg/state", + "//pkg/state/wire", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go index f8385c146..78317fa35 100644 --- a/pkg/sentry/pgalloc/save_restore.go +++ b/pkg/sentry/pgalloc/save_restore.go @@ -26,11 +26,12 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/wire" "gvisor.dev/gvisor/pkg/usermem" ) // SaveTo writes f's state to the given stream. -func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { +func (f *MemoryFile) SaveTo(ctx context.Context, w wire.Writer) error { // Wait for reclaim. f.mu.Lock() defer f.mu.Unlock() @@ -79,10 +80,10 @@ func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { } // Save metadata. - if err := state.Save(ctx, w, &f.fileSize, nil); err != nil { + if _, err := state.Save(ctx, w, &f.fileSize); err != nil { return err } - if err := state.Save(ctx, w, &f.usage, nil); err != nil { + if _, err := state.Save(ctx, w, &f.usage); err != nil { return err } @@ -115,9 +116,9 @@ func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { } // LoadFrom loads MemoryFile state from the given stream. -func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error { +func (f *MemoryFile) LoadFrom(ctx context.Context, r wire.Reader) error { // Load metadata. - if err := state.Load(ctx, r, &f.fileSize, nil); err != nil { + if _, err := state.Load(ctx, r, &f.fileSize); err != nil { return err } if err := f.file.Truncate(f.fileSize); err != nil { @@ -125,7 +126,7 @@ func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error { } newMappings := make([]uintptr, f.fileSize>>chunkShift) f.mappings.Store(newMappings) - if err := state.Load(ctx, r, &f.usage, nil); err != nil { + if _, err := state.Load(ctx, r, &f.usage); err != nil { return err } diff --git a/pkg/state/BUILD b/pkg/state/BUILD index 2b1350135..089b3bbef 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -1,9 +1,47 @@ -load("//tools:defs.bzl", "go_library", "go_test", "proto_library") +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) go_template_instance( + name = "pending_list", + out = "pending_list.go", + package = "state", + prefix = "pending", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*objectEncodeState", + "ElementMapper": "pendingMapper", + "Linker": "*pendingEntry", + }, +) + +go_template_instance( + name = "deferred_list", + out = "deferred_list.go", + package = "state", + prefix = "deferred", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*objectEncodeState", + "ElementMapper": "deferredMapper", + "Linker": "*deferredEntry", + }, +) + +go_template_instance( + name = "complete_list", + out = "complete_list.go", + package = "state", + prefix = "complete", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*objectDecodeState", + "Linker": "*objectDecodeState", + }, +) + +go_template_instance( name = "addr_range", out = "addr_range.go", package = "state", @@ -29,7 +67,7 @@ go_template_instance( types = { "Key": "uintptr", "Range": "addrRange", - "Value": "reflect.Value", + "Value": "*objectEncodeState", "Functions": "addrSetFunctions", }, ) @@ -39,32 +77,24 @@ go_library( srcs = [ "addr_range.go", "addr_set.go", + "complete_list.go", "decode.go", + "decode_unsafe.go", + "deferred_list.go", "encode.go", "encode_unsafe.go", - "map.go", - "printer.go", + "pending_list.go", "state.go", + "state_norace.go", + "state_race.go", "stats.go", + "types.go", ], marshal = False, stateify = False, visibility = ["//:sandbox"], deps = [ - ":object_go_proto", - "@com_github_golang_protobuf//proto:go_default_library", + "//pkg/log", + "//pkg/state/wire", ], ) - -proto_library( - name = "object", - srcs = ["object.proto"], - visibility = ["//:sandbox"], -) - -go_test( - name = "state_test", - timeout = "long", - srcs = ["state_test.go"], - library = ":state", -) diff --git a/pkg/state/README.md b/pkg/state/README.md new file mode 100644 index 000000000..1aa401193 --- /dev/null +++ b/pkg/state/README.md @@ -0,0 +1,158 @@ +# State Encoding and Decoding + +The state package implements the encoding and decoding of data structures for +`go_stateify`. This package is designed for use cases other than the standard +encoding packages, e.g. `gob` and `json`. Principally: + +* This package operates on complex object graphs and accurately serializes and + restores all relationships. That is, you can have things like: intrusive + pointers, cycles, and pointer chains of arbitrary depths. These are not + handled appropriately by existing encoders. This is not an implementation + flaw: the formats themselves are not capable of representing these graphs, + as they can only generate directed trees. + +* This package allows installing order-dependent load callbacks and then + resolves that graph at load time, with cycle detection. Similarly, there is + no analogous feature possible in the standard encoders. + +* This package handles the resolution of interfaces, based on a registered + type name. For interface objects type information is saved in the serialized + format. This is generally true for `gob` as well, but it works differently. + +Here's an overview of how encoding and decoding works. + +## Encoding + +Encoding produces a `statefile`, which contains a list of chunks of the form +`(header, payload)`. The payload can either be some raw data, or a series of +encoded wire objects representing some object graph. All encoded objects are +defined in the `wire` subpackage. + +Encoding of an object graph begins with `encodeState.Save`. + +### 1. Memory Map & Encoding + +To discover relationships between potentially interdependent data structures +(for example, a struct may contain pointers to members of other data +structures), the encoder first walks the object graph and constructs a memory +map of the objects in the input graph. As this walk progresses, objects are +queued in the `pending` list and items are placed on the `deferred` list as they +are discovered. No single object will be encoded multiple times, but the +discovered relationships between objects may change as more parts of the overall +object graph are discovered. + +The encoder starts at the root object and recursively visits all reachable +objects, recording the address ranges containing the underlying data for each +object. This is stored as a segment set (`addrSet`), mapping address ranges to +the of the object occupying the range; see `encodeState.values`. Note that there +is special handling for zero-sized types and map objects during this process. + +Additionally, the encoder assigns each object a unique identifier which is used +to indicate relationships between objects in the statefile; see `objectID` in +`encode.go`. + +### 2. Type Serialization + +The enoder will subsequently serialize all information about discovered types, +including field names. These are used during decoding to reconcile these types +with other internally registered types. + +### 3. Object Serialization + +With a full address map, and all objects correctly encoded, all object encodings +are serialized. The assigned `objectID`s aren't explicitly encoded in the +statefile. The order of object messages in the stream determine their IDs. + +### Example + +Given the following data structure definitions: + +```go +type system struct { + o *outer + i *inner +} + +type outer struct { + a int64 + cn *container +} + +type container struct { + n uint64 + elem *inner +} + +type inner struct { + c container + x, y uint64 +} +``` + +Initialized like this: + +```go +o := outer{ + a: 10, + cn: nil, +} +i := inner{ + x: 20, + y: 30, + c: container{}, +} +s := system{ + o: &o, + i: &i, +} + +o.cn = &i.c +o.cn.elem = &i + +``` + +Encoding will produce an object stream like this: + +``` +g0r1 = struct{ + i: g0r3, + o: g0r2, +} +g0r2 = struct{ + a: 10, + cn: g0r3.c, +} +g0r3 = struct{ + c: struct{ + elem: g0r3, + n: 0u, + }, + x: 20u, + y: 30u, +} +``` + +Note how `g0r3.c` is correctly encoded as the underlying `container` object for +`inner.c`, and how the pointer from `outer.cn` points to it, despite `system.i` +being discovered after the pointer to it in `system.o.cn`. Also note that +decoding isn't strictly reliant on the order of encoded object stream, as long +as the relationship between objects are correctly encoded. + +## Decoding + +Decoding reads the statefile and reconstructs the object graph. Decoding begins +in `decodeState.Load`. Decoding is performed in a single pass over the object +stream in the statefile, and a subsequent pass over all deserialized objects is +done to fire off all loading callbacks in the correctly defined order. Note that +introducing cycles is possible here, but these are detected and an error will be +returned. + +Decoding is relatively straight forward. For most primitive values, the decoder +constructs an appropriate object and fills it with the values encoded in the +statefile. Pointers need special handling, as they must point to a value +allocated elsewhere. When values are constructed, the decoder indexes them by +their `objectID`s in `decodeState.objectsByID`. The target of pointers are +resolved by searching for the target in this index by their `objectID`; see +`decodeState.register`. For pointers to values inside another value (fields in a +pointer, elements of an array), the decoder uses the accessor path to walk to +the appropriate location; see `walkChild`. diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 590c241a3..c9971cdf6 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -17,28 +17,49 @@ package state import ( "bytes" "context" - "encoding/binary" - "errors" "fmt" - "io" + "math" "reflect" - "sort" - "github.com/golang/protobuf/proto" - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" + "gvisor.dev/gvisor/pkg/state/wire" ) -// objectState represents an object that may be in the process of being +// internalCallback is a interface called on object completion. +// +// There are two implementations: objectDecodeState & userCallback. +type internalCallback interface { + // source returns the dependent object. May be nil. + source() *objectDecodeState + + // callbackRun executes the callback. + callbackRun() +} + +// userCallback is an implementation of internalCallback. +type userCallback func() + +// source implements internalCallback.source. +func (userCallback) source() *objectDecodeState { + return nil +} + +// callbackRun implements internalCallback.callbackRun. +func (uc userCallback) callbackRun() { + uc() +} + +// objectDecodeState represents an object that may be in the process of being // decoded. Specifically, it represents either a decoded object, or an an // interest in a future object that will be decoded. When that interest is // registered (via register), the storage for the object will be created, but // it will not be decoded until the object is encountered in the stream. -type objectState struct { +type objectDecodeState struct { // id is the id for this object. - // - // If this field is zero, then this is an anonymous (unregistered, - // non-reference primitive) object. This is immutable. - id uint64 + id objectID + + // typ is the id for this typeID. This may be zero if this is not a + // type-registered structure. + typ typeID // obj is the object. This may or may not be valid yet, depending on // whether complete returns true. However, regardless of whether the @@ -57,69 +78,52 @@ type objectState struct { // blockedBy is the number of dependencies this object has. blockedBy int - // blocking is a list of the objects blocked by this one. - blocking []*objectState + // callbacksInline is inline storage for callbacks. + callbacksInline [2]internalCallback // callbacks is a set of callbacks to execute on load. - callbacks []func() - - // path is the decoding path to the object. - path recoverable -} - -// complete indicates the object is complete. -func (os *objectState) complete() bool { - return os.blockedBy == 0 && len(os.callbacks) == 0 -} - -// checkComplete checks for completion. If the object is complete, pending -// callbacks will be executed and checkComplete will be called on downstream -// objects (those depending on this one). -func (os *objectState) checkComplete(stats *Stats) { - if os.blockedBy > 0 { - return - } - stats.Start(os.obj) + callbacks []internalCallback - // Fire all callbacks. - for _, fn := range os.callbacks { - fn() - } - os.callbacks = nil - - // Clear all blocked objects. - for _, other := range os.blocking { - other.blockedBy-- - other.checkComplete(stats) - } - os.blocking = nil - stats.Done() + completeEntry } -// waitFor queues a dependency on the given object. -func (os *objectState) waitFor(other *objectState, callback func()) { - os.blockedBy++ - other.blocking = append(other.blocking, os) - if callback != nil { - other.callbacks = append(other.callbacks, callback) +// addCallback adds a callback to the objectDecodeState. +func (ods *objectDecodeState) addCallback(ic internalCallback) { + if ods.callbacks == nil { + ods.callbacks = ods.callbacksInline[:0] } + ods.callbacks = append(ods.callbacks, ic) } // findCycleFor returns when the given object is found in the blocking set. -func (os *objectState) findCycleFor(target *objectState) []*objectState { - for _, other := range os.blocking { - if other == target { - return []*objectState{target} +func (ods *objectDecodeState) findCycleFor(target *objectDecodeState) []*objectDecodeState { + for _, ic := range ods.callbacks { + other := ic.source() + if other != nil && other == target { + return []*objectDecodeState{target} } else if childList := other.findCycleFor(target); childList != nil { return append(childList, other) } } - return nil + + // This should not occur. + Failf("no deadlock found?") + panic("unreachable") } // findCycle finds a dependency cycle. -func (os *objectState) findCycle() []*objectState { - return append(os.findCycleFor(os), os) +func (ods *objectDecodeState) findCycle() []*objectDecodeState { + return append(ods.findCycleFor(ods), ods) +} + +// source implements internalCallback.source. +func (ods *objectDecodeState) source() *objectDecodeState { + return ods +} + +// callbackRun implements internalCallback.callbackRun. +func (ods *objectDecodeState) callbackRun() { + ods.blockedBy-- } // decodeState is a graph of objects in the process of being decoded. @@ -137,30 +141,66 @@ type decodeState struct { // ctx is the decode context. ctx context.Context + // r is the input stream. + r wire.Reader + + // types is the type database. + types typeDecodeDatabase + // objectByID is the set of objects in progress. - objectsByID map[uint64]*objectState + objectsByID []*objectDecodeState // deferred are objects that have been read, by no interest has been // registered yet. These will be decoded once interest in registered. - deferred map[uint64]*pb.Object + deferred map[objectID]wire.Object - // outstanding is the number of outstanding objects. - outstanding uint32 + // pending is the set of objects that are not yet complete. + pending completeList - // r is the input stream. - r io.Reader - - // stats is the passed stats object. - stats *Stats - - // recoverable is the panic recover facility. - recoverable + // stats tracks time data. + stats Stats } // lookup looks up an object in decodeState or returns nil if no such object // has been previously registered. -func (ds *decodeState) lookup(id uint64) *objectState { - return ds.objectsByID[id] +func (ds *decodeState) lookup(id objectID) *objectDecodeState { + if len(ds.objectsByID) < int(id) { + return nil + } + return ds.objectsByID[id-1] +} + +// checkComplete checks for completion. +func (ds *decodeState) checkComplete(ods *objectDecodeState) bool { + // Still blocked? + if ods.blockedBy > 0 { + return false + } + + // Track stats if relevant. + if ods.callbacks != nil && ods.typ != 0 { + ds.stats.start(ods.typ) + defer ds.stats.done() + } + + // Fire all callbacks. + for _, ic := range ods.callbacks { + ic.callbackRun() + } + + // Mark completed. + cbs := ods.callbacks + ods.callbacks = nil + ds.pending.Remove(ods) + + // Recursively check others. + for _, ic := range cbs { + if other := ic.source(); other != nil && other.blockedBy == 0 { + ds.checkComplete(other) + } + } + + return true // All set. } // wait registers a dependency on an object. @@ -168,11 +208,8 @@ func (ds *decodeState) lookup(id uint64) *objectState { // As a special case, we always allow _useable_ references back to the first // decoding object because it may have fields that are already decoded. We also // allow trivial self reference, since they can be handled internally. -func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) { +func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback func()) { switch id { - case 0: - // Nil pointer; nothing to wait for. - fallthrough case waiter.id: // Trivial self reference. fallthrough @@ -184,107 +221,188 @@ func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) { return } + // Mark as blocked. + waiter.blockedBy++ + // No nil can be returned here. - waiter.waitFor(ds.lookup(id), callback) + other := ds.lookup(id) + if callback != nil { + // Add the additional user callback. + other.addCallback(userCallback(callback)) + } + + // Mark waiter as unblocked. + other.addCallback(waiter) } // waitObject notes a blocking relationship. -func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) { - if rv, ok := p.Value.(*pb.Object_RefValue); ok { +func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, callback func()) { + if rv, ok := encoded.(*wire.Ref); ok && rv.Root != 0 { // Refs can encode pointers and maps. - ds.wait(os, rv.RefValue, callback) - } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok { + ds.wait(ods, objectID(rv.Root), callback) + } else if sv, ok := encoded.(*wire.Slice); ok && sv.Ref.Root != 0 { // See decodeObject; we need to wait for the array (if non-nil). - ds.wait(os, sv.SliceValue.RefValue, callback) - } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok { + ds.wait(ods, objectID(sv.Ref.Root), callback) + } else if iv, ok := encoded.(*wire.Interface); ok { // It's an interface (wait recurisvely). - ds.waitObject(os, iv.InterfaceValue.Value, callback) + ds.waitObject(ods, iv.Value, callback) } else if callback != nil { // Nothing to wait for: execute the callback immediately. callback() } } +// walkChild returns a child object from obj, given an accessor path. This is +// the decode-side equivalent to traverse in encode.go. +// +// For the purposes of this function, a child object is either a field within a +// struct or an array element, with one such indirection per element in +// path. The returned value may be an unexported field, so it may not be +// directly assignable. See unsafePointerTo. +func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value { + // See wire.Ref.Dots. The path here is specified in reverse order. + for i := len(path) - 1; i >= 0; i-- { + switch pc := path[i].(type) { + case *wire.FieldName: // Must be a pointer. + if obj.Kind() != reflect.Struct { + Failf("next component in child path is a field name, but the current object is not a struct. Path: %v, current obj: %#v", path, obj) + } + obj = obj.FieldByName(string(*pc)) + case wire.Index: // Embedded. + if obj.Kind() != reflect.Array { + Failf("next component in child path is an array index, but the current object is not an array. Path: %v, current obj: %#v", path, obj) + } + obj = obj.Index(int(pc)) + default: + panic("unreachable: switch should be exhaustive") + } + } + return obj +} + // register registers a decode with a type. // // This type is only used to instantiate a new object if it has not been -// registered previously. -func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState { - os, ok := ds.objectsByID[id] - if ok { - return os +// registered previously. This depends on the type provided if none is +// available in the object itself. +func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value { + // Grow the objectsByID slice. + id := objectID(r.Root) + if len(ds.objectsByID) < int(id) { + ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...) + } + + // Does this object already exist? + ods := ds.objectsByID[id-1] + if ods != nil { + return walkChild(r.Dots, ods.obj) + } + + // Create the object. + if len(r.Dots) != 0 { + typ = ds.findType(r.Type) } + v := reflect.New(typ) + ods = &objectDecodeState{ + id: id, + obj: v.Elem(), + } + ds.objectsByID[id-1] = ods + ds.pending.PushBack(ods) - // Record in the object index. - if typ.Kind() == reflect.Map { - os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()} - } else { - os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()} + // Process any deferred objects & callbacks. + if encoded, ok := ds.deferred[id]; ok { + delete(ds.deferred, id) + ds.decodeObject(ods, ods.obj, encoded) } - ds.objectsByID[id] = os - if o, ok := ds.deferred[id]; ok { - // There is a deferred object. - delete(ds.deferred, id) // Free memory. - ds.decodeObject(os, os.obj, o, "", nil) - } else { - // There is no deferred object. - ds.outstanding++ + return walkChild(r.Dots, ods.obj) +} + +// objectDecoder is for decoding structs. +type objectDecoder struct { + // ds is decodeState. + ds *decodeState + + // ods is current object being decoded. + ods *objectDecodeState + + // reconciledTypeEntry is the reconciled type information. + rte *reconciledTypeEntry + + // encoded is the encoded object state. + encoded *wire.Struct +} + +// load is helper for the public methods on Source. +func (od *objectDecoder) load(slot int, objPtr reflect.Value, wait bool, fn func()) { + // Note that we have reconciled the type and may remap the fields here + // to match what's expected by the decoder. The "slot" parameter here + // is in terms of the local type, where the fields in the encoded + // object are in terms of the wire object's type, which might be in a + // different order (but will have the same fields). + v := *od.encoded.Field(od.rte.FieldOrder[slot]) + od.ds.decodeObject(od.ods, objPtr.Elem(), v) + if wait { + // Mark this individual object a blocker. + od.ds.waitObject(od.ods, v, fn) } +} - return os +// aterLoad implements Source.AfterLoad. +func (od *objectDecoder) afterLoad(fn func()) { + // Queue the local callback; this will execute when all of the above + // data dependencies have been cleared. + od.ods.addCallback(userCallback(fn)) } // decodeStruct decodes a struct value. -func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) { - // Set the fields. - m := Map{newInternalMap(nil, ds, os)} - defer internalMapPool.Put(m.internalMap) - for _, field := range s.Fields { - m.data = append(m.data, entry{ - name: field.Name, - object: field.Value, - }) - } - - // Sort the fields for efficient searching. - // - // Technically, these should already appear in sorted order in the - // state ordering, so this cost is effectively a single scan to ensure - // that the order is correct. - if len(m.data) > 1 { - sort.Slice(m.data, func(i, j int) bool { - return m.data[i].name < m.data[j].name - }) - } - - // Invoke the load; this will recursively decode other objects. - fns, ok := registeredTypes.lookupFns(obj.Addr().Type()) - if ok { - // Invoke the loader. - fns.invokeLoad(obj.Addr(), m) - } else if obj.NumField() == 0 { - // Allow anonymous empty structs. - return - } else { +func (ds *decodeState) decodeStruct(ods *objectDecodeState, obj reflect.Value, encoded *wire.Struct) { + if encoded.TypeID == 0 { + // Allow anonymous empty structs, but only if the encoded + // object also has no fields. + if encoded.Fields() == 0 && obj.NumField() == 0 { + return + } + // Propagate an error. - panic(fmt.Errorf("unregistered type %s", obj.Type())) + Failf("empty struct on wire %#v has field mismatch with type %q", encoded, obj.Type().Name()) + } + + // Lookup the object type. + rte := ds.types.Lookup(typeID(encoded.TypeID), obj.Type()) + ods.typ = typeID(encoded.TypeID) + + // Invoke the loader. + od := objectDecoder{ + ds: ds, + ods: ods, + rte: rte, + encoded: encoded, + } + ds.stats.start(ods.typ) + defer ds.stats.done() + if sl, ok := obj.Addr().Interface().(SaverLoader); ok { + // Note: may be a registered empty struct which does not + // implement the saver/loader interfaces. + sl.StateLoad(Source{internal: od}) } } // decodeMap decodes a map value. -func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) { +func (ds *decodeState) decodeMap(ods *objectDecodeState, obj reflect.Value, encoded *wire.Map) { if obj.IsNil() { + // See pointerTo. obj.Set(reflect.MakeMap(obj.Type())) } - for i := 0; i < len(m.Keys); i++ { + for i := 0; i < len(encoded.Keys); i++ { // Decode the objects. kv := reflect.New(obj.Type().Key()).Elem() vv := reflect.New(obj.Type().Elem()).Elem() - ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i) - ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface()) - ds.waitObject(os, m.Keys[i], nil) - ds.waitObject(os, m.Values[i], nil) + ds.decodeObject(ods, kv, encoded.Keys[i]) + ds.decodeObject(ods, vv, encoded.Values[i]) + ds.waitObject(ods, encoded.Keys[i], nil) + ds.waitObject(ods, encoded.Values[i], nil) // Set in the map. obj.SetMapIndex(kv, vv) @@ -292,271 +410,294 @@ func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) } // decodeArray decodes an array value. -func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) { - if len(a.Contents) != obj.Len() { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents))) +func (ds *decodeState) decodeArray(ods *objectDecodeState, obj reflect.Value, encoded *wire.Array) { + if len(encoded.Contents) != obj.Len() { + Failf("mismatching array length expect=%d, actual=%d", obj.Len(), len(encoded.Contents)) } // Decode the contents into the array. - for i := 0; i < len(a.Contents); i++ { - ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i) - ds.waitObject(os, a.Contents[i], nil) + for i := 0; i < len(encoded.Contents); i++ { + ds.decodeObject(ods, obj.Index(i), encoded.Contents[i]) + ds.waitObject(ods, encoded.Contents[i], nil) } } -// decodeInterface decodes an interface value. -func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) { - // Is this a nil value? - if i.Type == "" { - return // Just leave obj alone. +// findType finds the type for the given wire.TypeSpecs. +func (ds *decodeState) findType(t wire.TypeSpec) reflect.Type { + switch x := t.(type) { + case wire.TypeID: + typ := ds.types.LookupType(typeID(x)) + rte := ds.types.Lookup(typeID(x), typ) + return rte.LocalType + case *wire.TypeSpecPointer: + return reflect.PtrTo(ds.findType(x.Type)) + case *wire.TypeSpecArray: + return reflect.ArrayOf(int(x.Count), ds.findType(x.Type)) + case *wire.TypeSpecSlice: + return reflect.SliceOf(ds.findType(x.Type)) + case *wire.TypeSpecMap: + return reflect.MapOf(ds.findType(x.Key), ds.findType(x.Value)) + default: + // Should not happen. + Failf("unknown type %#v", t) } + panic("unreachable") +} - // Get the dispatchable type. This may not be used if the given - // reference has already been resolved, but if not we need to know the - // type to create. - t, ok := registeredTypes.lookupType(i.Type) - if !ok { - panic(fmt.Errorf("no valid type for %q", i.Type)) +// decodeInterface decodes an interface value. +func (ds *decodeState) decodeInterface(ods *objectDecodeState, obj reflect.Value, encoded *wire.Interface) { + if _, ok := encoded.Type.(wire.TypeSpecNil); ok { + // Special case; the nil object. Just decode directly, which + // will read nil from the wire (if encoded correctly). + ds.decodeObject(ods, obj, encoded.Value) + return } - if obj.Kind() != reflect.Map { - // Set the obj to be the given typed value; this actually sets - // obj to be a non-zero value -- namely, it inserts type - // information. There's no need to do this for maps. - obj.Set(reflect.Zero(t)) + // We now need to resolve the actual type. + typ := ds.findType(encoded.Type) + + // We need to imbue type information here, then we can proceed to + // decode normally. In order to avoid issues with setting value-types, + // we create a new non-interface version of this object. We will then + // set the interface object to be equal to whatever we decode. + origObj := obj + obj = reflect.New(typ).Elem() + defer origObj.Set(obj) + + // With the object now having sufficient type information to actually + // have Set called on it, we can proceed to decode the value. + ds.decodeObject(ods, obj, encoded.Value) +} + +// isFloatEq determines if x and y represent the same value. +func isFloatEq(x float64, y float64) bool { + switch { + case math.IsNaN(x): + return math.IsNaN(y) + case math.IsInf(x, 1): + return math.IsInf(y, 1) + case math.IsInf(x, -1): + return math.IsInf(y, -1) + default: + return x == y } +} - // Decode the dereferenced element; there is no need to wait here, as - // the interface object shares the current object state. - ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type) +// isComplexEq determines if x and y represent the same value. +func isComplexEq(x complex128, y complex128) bool { + return isFloatEq(real(x), real(y)) && isFloatEq(imag(x), imag(y)) } // decodeObject decodes a object value. -func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) { - ds.push(false, format, param) - ds.stats.Add(obj) - ds.stats.Start(obj) - - switch x := object.GetValue().(type) { - case *pb.Object_BoolValue: - obj.SetBool(x.BoolValue) - case *pb.Object_StringValue: - obj.SetString(string(x.StringValue)) - case *pb.Object_Int64Value: - obj.SetInt(x.Int64Value) - if obj.Int() != x.Int64Value { - panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type())) +func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, encoded wire.Object) { + switch x := encoded.(type) { + case wire.Nil: // Fast path: first. + // We leave obj alone here. That's because if obj represents an + // interface, it may have been imbued with type information in + // decodeInterface, and we don't want to destroy that. + case *wire.Ref: + // Nil pointers may be encoded in a "forceValue" context. For + // those we just leave it alone as the value will already be + // correct (nil). + if id := objectID(x.Root); id == 0 { + return } - case *pb.Object_Uint64Value: - obj.SetUint(x.Uint64Value) - if obj.Uint() != x.Uint64Value { - panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type())) - } - case *pb.Object_DoubleValue: - obj.SetFloat(x.DoubleValue) - if obj.Float() != x.DoubleValue { - panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type())) - } - case *pb.Object_RefValue: - // Resolve the pointer itself, even though the object may not - // be decoded yet. You need to use wait() in order to ensure - // that is the case. See wait above, and Map.Barrier. - if id := x.RefValue; id != 0 { - // Decoding the interface should have imparted type - // information, so from this point it's safe to resolve - // and use this dynamic information for actually - // creating the object in register. - // - // (For non-interfaces this is a no-op). - dyntyp := reflect.TypeOf(obj.Interface()) - if dyntyp.Kind() == reflect.Map { - // Remove the map object count here to avoid - // double counting, as this object will be - // counted again when it gets processed later. - // We do not add a reference count as the - // reference is artificial. - ds.stats.Remove(obj) - obj.Set(ds.register(id, dyntyp).obj) - } else if dyntyp.Kind() == reflect.Ptr { - ds.push(true /* dereference */, "", nil) - obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) - ds.pop() - } else { - obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) + + // Note that if this is a map type, we go through a level of + // indirection to allow for map aliasing. + if obj.Kind() == reflect.Map { + v := ds.register(x, obj.Type()) + if v.IsNil() { + // Note that we don't want to clobber the map + // if has already been decoded by decodeMap. We + // just make it so that we have a consistent + // reference when that eventually does happen. + v.Set(reflect.MakeMap(v.Type())) } - } else { - // We leave obj alone here. That's because if obj - // represents an interface, it may have been embued - // with type information in decodeInterface, and we - // don't want to destroy that information. + obj.Set(v) + return } - case *pb.Object_SliceValue: - // It's okay to slice the array here, since the contents will - // still be provided later on. These semantics are a bit - // strange but they are handled in the Map.Barrier properly. - // - // The special semantics of zero ref apply here too. - if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 { - v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem()) - obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity))) + + // Normal assignment: authoritative only if no dots. + v := ds.register(x, obj.Type().Elem()) + if v.IsValid() { + obj.Set(unsafePointerTo(v)) } - case *pb.Object_ArrayValue: - ds.decodeArray(os, obj, x.ArrayValue) - case *pb.Object_StructValue: - ds.decodeStruct(os, obj, x.StructValue) - case *pb.Object_MapValue: - ds.decodeMap(os, obj, x.MapValue) - case *pb.Object_InterfaceValue: - ds.decodeInterface(os, obj, x.InterfaceValue) - case *pb.Object_ByteArrayValue: - copyArray(obj, reflect.ValueOf(x.ByteArrayValue)) - case *pb.Object_Uint16ArrayValue: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details. - s := x.Uint16ArrayValue.Values - t := obj.Slice(0, obj.Len()).Interface().([]uint16) - if len(t) != len(s) { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) + case wire.Bool: + obj.SetBool(bool(x)) + case wire.Int: + obj.SetInt(int64(x)) + if obj.Int() != int64(x) { + Failf("signed integer truncated from %v to %v", int64(x), obj.Int()) } - for i := range s { - t[i] = uint16(s[i]) + case wire.Uint: + obj.SetUint(uint64(x)) + if obj.Uint() != uint64(x) { + Failf("unsigned integer truncated from %v to %v", uint64(x), obj.Uint()) } - case *pb.Object_Uint32ArrayValue: - copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values)) - case *pb.Object_Uint64ArrayValue: - copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values)) - case *pb.Object_UintptrArrayValue: - copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0)))) - case *pb.Object_Int8ArrayValue: - copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0)))) - case *pb.Object_Int16ArrayValue: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details. - s := x.Int16ArrayValue.Values - t := obj.Slice(0, obj.Len()).Interface().([]int16) - if len(t) != len(s) { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) + case wire.Float32: + obj.SetFloat(float64(x)) + case wire.Float64: + obj.SetFloat(float64(x)) + if !isFloatEq(obj.Float(), float64(x)) { + Failf("floating point number truncated from %v to %v", float64(x), obj.Float()) } - for i := range s { - t[i] = int16(s[i]) + case *wire.Complex64: + obj.SetComplex(complex128(*x)) + case *wire.Complex128: + obj.SetComplex(complex128(*x)) + if !isComplexEq(obj.Complex(), complex128(*x)) { + Failf("complex number truncated from %v to %v", complex128(*x), obj.Complex()) } - case *pb.Object_Int32ArrayValue: - copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values)) - case *pb.Object_Int64ArrayValue: - copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values)) - case *pb.Object_BoolArrayValue: - copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values)) - case *pb.Object_Float64ArrayValue: - copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values)) - case *pb.Object_Float32ArrayValue: - copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values)) + case *wire.String: + obj.SetString(string(*x)) + case *wire.Slice: + // See *wire.Ref above; same applies. + if id := objectID(x.Ref.Root); id == 0 { + return + } + // Note that it's fine to slice the array here and assume that + // contents will still be filled in later on. + typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type. + v := ds.register(&x.Ref, typ) + obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity))) + case *wire.Array: + ds.decodeArray(ods, obj, x) + case *wire.Struct: + ds.decodeStruct(ods, obj, x) + case *wire.Map: + ds.decodeMap(ods, obj, x) + case *wire.Interface: + ds.decodeInterface(ods, obj, x) default: // Shoud not happen, not propagated as an error. - panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type())) - } - - ds.stats.Done() - ds.pop() -} - -func copyArray(dest reflect.Value, src reflect.Value) { - if dest.Len() != src.Len() { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len())) + Failf("unknown object %#v for %q", encoded, obj.Type().Name()) } - reflect.Copy(dest, castSlice(src, dest.Type().Elem())) } -// Deserialize deserializes the object state. +// Load deserializes the object graph rooted at obj. // // This function may panic and should be run in safely(). -func (ds *decodeState) Deserialize(obj reflect.Value) { - ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()} - ds.outstanding = 1 // The root object. +func (ds *decodeState) Load(obj reflect.Value) { + ds.stats.init() + defer ds.stats.fini(func(id typeID) string { + return ds.types.LookupName(id) + }) + + // Create the root object. + ds.objectsByID = append(ds.objectsByID, &objectDecodeState{ + id: 1, + obj: obj, + }) + + // Read the number of objects. + lastID, object, err := ReadHeader(ds.r) + if err != nil { + Failf("header error: %w", err) + } + if !object { + Failf("object missing") + } + + // Decode all objects. + var ( + encoded wire.Object + ods *objectDecodeState + id = objectID(1) + tid = typeID(1) + ) + if err := safely(func() { + // Decode all objects in the stream. + // + // Note that the structure of this decoding loop should match + // the raw decoding loop in printer.go. + for id <= objectID(lastID) { + // Unmarshal the object. + encoded = wire.Load(ds.r) + + // Is this a type object? Handle inline. + if wt, ok := encoded.(*wire.Type); ok { + ds.types.Register(wt) + tid++ + encoded = nil + continue + } - // Decode all objects in the stream. - // - // See above, we never process objects while we have no outstanding - // interests (other than the very first object). - for id := uint64(1); ds.outstanding > 0; id++ { - os := ds.lookup(id) - ds.stats.Start(os.obj) - - o, err := ds.readObject() - if err != nil { - panic(err) - } + // Actually resolve the object. + ods = ds.lookup(id) + if ods != nil { + // Decode the object. + ds.decodeObject(ods, ods.obj, encoded) + } else { + // If an object hasn't had interest registered + // previously or isn't yet valid, we deferred + // decoding until interest is registered. + ds.deferred[id] = encoded + } - if os != nil { - // Decode the object. - ds.from = &os.path - ds.decodeObject(os, os.obj, o, "", nil) - ds.outstanding-- + // For error handling. + ods = nil + encoded = nil + id++ + } + }); err != nil { + // Include as much information as we can, taking into account + // the possible state transitions above. + if ods != nil { + Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err) + } else if encoded != nil { + Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err) } else { - // If an object hasn't had interest registered - // previously, we deferred decoding until interest is - // registered. - ds.deferred[id] = o + Failf("general decoding error: %w", err) } - - ds.stats.Done() - } - - // Check the zero-length header at the end. - length, object, err := ReadHeader(ds.r) - if err != nil { - panic(err) - } - if length != 0 { - panic(fmt.Sprintf("expected zero-length terminal, got %d", length)) - } - if object { - panic("expected non-object terminal") } // Check if we have any deferred objects. - if count := len(ds.deferred); count > 0 { - // Shoud not happen, not propagated as an error. - panic(fmt.Sprintf("still have %d deferred objects", count)) - } - - // Scan and fire all callbacks. - for _, os := range ds.objectsByID { - os.checkComplete(ds.stats) + for id, encoded := range ds.deferred { + // Shoud never happen, the graph was bogus. + Failf("still have deferred objects: one is ID %d, %#v", id, encoded) } - // Check if we have any remaining dependency cycles. - for _, os := range ds.objectsByID { - if !os.complete() { - // This must be the result of a dependency cycle. - cycle := os.findCycle() - var buf bytes.Buffer - buf.WriteString("dependency cycle: {") - for i, cycleOS := range cycle { - if i > 0 { - buf.WriteString(" => ") + // Scan and fire all callbacks. We iterate over the list of incomplete + // objects until all have been finished. We stop iterating if no + // objects become complete (there is a dependency cycle). + // + // Note that we iterate backwards here, because there will be a strong + // tendendcy for blocking relationships to go from earlier objects to + // later (deeper) objects in the graph. This will reduce the number of + // iterations required to finish all objects. + if err := safely(func() { + for ds.pending.Back() != nil { + thisCycle := false + for ods = ds.pending.Back(); ods != nil; { + if ds.checkComplete(ods) { + thisCycle = true + break } - buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type())) + ods = ods.Prev() + } + if !thisCycle { + break } - buf.WriteString("}") - // Panic as an error; propagate to the caller. - panic(errors.New(string(buf.Bytes()))) } - } -} - -type byteReader struct { - io.Reader -} - -// ReadByte implements io.ByteReader. -func (br byteReader) ReadByte() (byte, error) { - var b [1]byte - n, err := br.Reader.Read(b[:]) - if n > 0 { - return b[0], nil - } else if err != nil { - return 0, err - } else { - return 0, io.ErrUnexpectedEOF + }); err != nil { + Failf("error executing callbacks for %#v: %w", ods.obj.Interface(), err) + } + + // Check if we have any remaining dependency cycles. If there are any + // objects left in the pending list, then it must be due to a cycle. + if ods := ds.pending.Front(); ods != nil { + // This must be the result of a dependency cycle. + cycle := ods.findCycle() + var buf bytes.Buffer + buf.WriteString("dependency cycle: {") + for i, cycleOS := range cycle { + if i > 0 { + buf.WriteString(" => ") + } + fmt.Fprintf(&buf, "%q", cycleOS.obj.Type()) + } + buf.WriteString("}") + Failf("incomplete graph: %s", string(buf.Bytes())) } } @@ -565,45 +706,20 @@ func (br byteReader) ReadByte() (byte, error) { // Each object written to the statefile is prefixed with a header. See // WriteHeader for more information; these functions are exported to allow // non-state writes to the file to play nice with debugging tools. -func ReadHeader(r io.Reader) (length uint64, object bool, err error) { +func ReadHeader(r wire.Reader) (length uint64, object bool, err error) { // Read the header. - length, err = binary.ReadUvarint(byteReader{r}) + err = safely(func() { + length = wire.LoadUint(r) + }) if err != nil { - return + // On the header, pass raw I/O errors. + if sErr, ok := err.(*ErrState); ok { + return 0, false, sErr.Unwrap() + } } // Decode whether the object is valid. - object = length&0x1 != 0 - length = length >> 1 + object = length&objectFlag != 0 + length &^= objectFlag return } - -// readObject reads an object from the stream. -func (ds *decodeState) readObject() (*pb.Object, error) { - // Read the header. - length, object, err := ReadHeader(ds.r) - if err != nil { - return nil, err - } - if !object { - return nil, fmt.Errorf("invalid object header") - } - - // Read the object. - buf := make([]byte, length) - for done := 0; done < len(buf); { - n, err := ds.r.Read(buf[done:]) - done += n - if n == 0 && err != nil { - return nil, err - } - } - - // Unmarshal. - obj := new(pb.Object) - if err := proto.Unmarshal(buf, obj); err != nil { - return nil, err - } - - return obj, nil -} diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go new file mode 100644 index 000000000..d048f61a1 --- /dev/null +++ b/pkg/state/decode_unsafe.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "reflect" + "unsafe" +) + +// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on +// values representing unexported fields. This bypasses visibility, but not +// type safety. +func unsafePointerTo(obj reflect.Value) reflect.Value { + return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr())) +} diff --git a/pkg/state/encode.go b/pkg/state/encode.go index c5118d3a9..92fcad4e9 100644 --- a/pkg/state/encode.go +++ b/pkg/state/encode.go @@ -15,437 +15,797 @@ package state import ( - "container/list" "context" - "encoding/binary" - "fmt" - "io" "reflect" - "sort" - "github.com/golang/protobuf/proto" - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" + "gvisor.dev/gvisor/pkg/state/wire" ) -// queuedObject is an object queued for encoding. -type queuedObject struct { - id uint64 - obj reflect.Value - path recoverable +// objectEncodeState the type and identity of an object occupying a memory +// address range. This is the value type for addrSet, and the intrusive entry +// for the pending and deferred lists. +type objectEncodeState struct { + // id is the assigned ID for this object. + id objectID + + // obj is the object value. Note that this may be replaced if we + // encounter an object that contains this object. When this happens (in + // resolve), we will update existing references approprately, below, + // and defer a re-encoding of the object. + obj reflect.Value + + // encoded is the encoded value of this object. Note that this may not + // be up to date if this object is still in the deferred list. + encoded wire.Object + + // how indicates whether this object should be encoded as a value. This + // is used only for deferred encoding. + how encodeStrategy + + // refs are the list of reference objects used by other objects + // referring to this object. When the object is updated, these + // references may be updated directly and automatically. + refs []*wire.Ref + + pendingEntry + deferredEntry } // encodeState is state used for encoding. // -// The encoding process is a breadth-first traversal of the object graph. The -// inherent races and dependencies are much simpler than the decode case. +// The encoding process constructs a representation of the in-memory graph of +// objects before a single object is serialized. This is done to ensure that +// all references can be fully disambiguated. See resolve for more details. type encodeState struct { // ctx is the encode context. ctx context.Context - // lastID is the last object ID. - // - // See idsByObject for context. Because of the special zero encoding - // used for reference values, the first ID must be 1. - lastID uint64 + // w is the output stream. + w wire.Writer - // idsByObject is a set of objects, indexed via: - // - // reflect.ValueOf(x).UnsafeAddr - // - // This provides IDs for objects. - idsByObject map[uintptr]uint64 + // types is the type database. + types typeEncodeDatabase + + // lastID is the last allocated object ID. + lastID objectID - // values stores values that span the addresses. + // values tracks the address ranges occupied by objects, along with the + // types of these objects. This is used to locate pointer targets, + // including pointers to fields within another type. // - // addrSet is a a generated type which efficiently stores ranges of - // addresses. When encoding pointers, these ranges are filled in and - // used to check for overlapping or conflicting pointers. This would - // indicate a pointer to an field, or a non-type safe value, neither of - // which are currently decodable. + // Multiple objects may overlap in memory iff the larger object fully + // contains the smaller one, and the type of the smaller object matches + // a field or array element's type at the appropriate offset. An + // arbitrary number of objects may be nested in this manner. // - // See the usage of values below for more context. + // Note that this does not track zero-sized objects, those are tracked + // by zeroValues below. values addrSet - // w is the output stream. - w io.Writer + // zeroValues tracks zero-sized objects. + zeroValues map[reflect.Type]*objectEncodeState - // pending is the list of objects to be serialized. - // - // This is a set of queuedObjects. - pending list.List + // deferred is the list of objects to be encoded. + deferred deferredList - // done is the a list of finished objects. - // - // This is kept to prevent garbage collection and address reuse. - done list.List + // pendingTypes is the list of types to be serialized. Serialization + // will occur when all objects have been encoded, but before pending is + // serialized. + pendingTypes []wire.Type - // stats is the passed stats object. - stats *Stats + // pending is the list of objects to be serialized. Serialization does + // not actually occur until the full object graph is computed. + pending pendingList - // recoverable is the panic recover facility. - recoverable + // stats tracks time data. + stats Stats } -// register looks up an ID, registering if necessary. +// isSameSizeParent returns true if child is a field value or element within +// parent. Only a struct or array can have a child value. +// +// isSameSizeParent deals with objects like this: +// +// struct child { +// // fields.. +// } // -// If the object was not previously registered, it is enqueued to be serialized. -// See the documentation for idsByObject for more information. -func (es *encodeState) register(obj reflect.Value) uint64 { - // It is not legal to call register for any non-pointer objects (see - // below), so we panic with a recoverable error if this is a mismatch. - if obj.Kind() != reflect.Ptr && obj.Kind() != reflect.Map { - panic(fmt.Errorf("non-pointer %#v registered", obj.Interface())) +// struct parent { +// c child +// } +// +// var p parent +// record(&p.c) +// +// Here, &p and &p.c occupy the exact same address range. +// +// Or like this: +// +// struct child { +// // fields +// } +// +// var arr [1]parent +// record(&arr[0]) +// +// Similarly, &arr[0] and &arr[0].c have the exact same address range. +// +// Precondition: parent and child must occupy the same memory. +func isSameSizeParent(parent reflect.Value, childType reflect.Type) bool { + switch parent.Kind() { + case reflect.Struct: + for i := 0; i < parent.NumField(); i++ { + field := parent.Field(i) + if field.Type() == childType { + return true + } + // Recurse through any intermediate types. + if isSameSizeParent(field, childType) { + return true + } + // Does it make sense to keep going if the first field + // doesn't match? Yes, because there might be an + // arbitrary number of zero-sized fields before we get + // a match, and childType itself can be zero-sized. + } + return false + case reflect.Array: + // The only case where an array with more than one elements can + // return true is if childType is zero-sized. In such cases, + // it's ambiguous which element contains the match since a + // zero-sized child object fully fits in any of the zero-sized + // elements in an array... However since all elements are of + // the same type, we only need to check one element. + // + // For non-zero-sized childTypes, parent.Len() must be 1, but a + // combination of the precondition and an implicit comparison + // between the array element size and childType ensures this. + return parent.Len() > 0 && isSameSizeParent(parent.Index(0), childType) + default: + return false } +} - addr := obj.Pointer() - if obj.Kind() == reflect.Ptr && obj.Elem().Type().Size() == 0 { - // For zero-sized objects, we always provide a unique ID. - // That's because the runtime internally multiplexes pointers - // to the same address. We can't be certain what the intent is - // with pointers to zero-sized objects, so we just give them - // all unique identities. - } else if id, ok := es.idsByObject[addr]; ok { - // Already registered. - return id - } - - // Ensure that the first ID given out is one. See note on lastID. The - // ID zero is used to indicate nil values. +// nextID returns the next valid ID. +func (es *encodeState) nextID() objectID { es.lastID++ - id := es.lastID - es.idsByObject[addr] = id - if obj.Kind() == reflect.Ptr { - // Dereference and treat as a pointer. - es.pending.PushBack(queuedObject{id: id, obj: obj.Elem(), path: es.recoverable.copy()}) - - // Register this object at all addresses. - typ := obj.Elem().Type() - if size := typ.Size(); size > 0 { - r := addrRange{addr, addr + size} - if !es.values.IsEmptyRange(r) { - old := es.values.LowerBoundSegment(addr).Value().Interface().(recoverable) - panic(fmt.Errorf("overlapping objects: [new object] %#v [existing object path] %s", obj.Interface(), old.path())) + return objectID(es.lastID) +} + +// dummyAddr points to the dummy zero-sized address. +var dummyAddr = reflect.ValueOf(new(struct{})).Pointer() + +// resolve records the address range occupied by an object. +func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { + addr := obj.Pointer() + + // Is this a map pointer? Just record the single address. It is not + // possible to take any pointers into the map internals. + if obj.Kind() == reflect.Map { + if addr == 0 { + // Just leave the nil reference alone. This is fine, we + // may need to encode as a reference in this way. We + // return nil for our objectEncodeState so that anyone + // depending on this value knows there's nothing there. + return + } + if seg, _ := es.values.Find(addr); seg.Ok() { + // Ensure the map types match. + existing := seg.Value() + if existing.obj.Type() != obj.Type() { + Failf("overlapping map objects at 0x%x: [new object] %#v [existing object type] %s", addr, obj, existing.obj) } - es.values.Add(r, reflect.ValueOf(es.recoverable.copy())) + + // No sense recording refs, maps may not be replaced by + // covering objects, they are maximal. + ref.Root = wire.Uint(existing.id) + return } + + // Record the map. + oes := &objectEncodeState{ + id: es.nextID(), + obj: obj, + how: encodeMapAsValue, + } + es.values.Add(addrRange{addr, addr + 1}, oes) + es.pending.PushBack(oes) + es.deferred.PushBack(oes) + + // See above: no ref recording. + ref.Root = wire.Uint(oes.id) + return + } + + // If not a map, then the object must be a pointer. + if obj.Kind() != reflect.Ptr { + Failf("attempt to record non-map and non-pointer object %#v", obj) + } + + obj = obj.Elem() // Value from here. + + // Is this a zero-sized type? + typ := obj.Type() + size := typ.Size() + if size == 0 { + if addr == dummyAddr { + // Zero-sized objects point to a dummy byte within the + // runtime. There's no sense recording this in the + // address map. We add this to the dedicated + // zeroValues. + // + // Note that zero-sized objects must be *true* + // zero-sized objects. They cannot be part of some + // larger object. In that case, they are assigned a + // 1-byte address at the end of the object. + oes, ok := es.zeroValues[typ] + if !ok { + oes = &objectEncodeState{ + id: es.nextID(), + obj: obj, + } + es.zeroValues[typ] = oes + es.pending.PushBack(oes) + es.deferred.PushBack(oes) + } + + // There's also no sense tracking back references. We + // know that this is a true zero-sized object, and not + // part of a larger container, so it will not change. + ref.Root = wire.Uint(oes.id) + return + } + size = 1 // See above. + } + + // Calculate the container. + end := addr + size + r := addrRange{addr, end} + if seg, _ := es.values.Find(addr); seg.Ok() { + existing := seg.Value() + switch { + case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type(): + // The object is a perfect match. Happy path. Avoid the + // traversal and just return directly. We don't need to + // encode the type information or any dots here. + ref.Root = wire.Uint(existing.id) + existing.refs = append(existing.refs, ref) + return + + case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end): + // The previously registered object is larger than + // this, no need to update. But we expect some + // traversal below. + + case seg.Start() == addr && seg.End() == end: + if !isSameSizeParent(obj, existing.obj.Type()) { + break // Needs traversal. + } + fallthrough // Needs update. + + case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end): + // Update the object and redo the encoding. + old := existing.obj + existing.obj = obj + es.deferred.Remove(existing) + es.deferred.PushBack(existing) + + // The previously registered object is superseded by + // this new object. We are guaranteed to not have any + // mergeable neighbours in this segment set. + if !raceEnabled { + seg.SetRangeUnchecked(r) + } else { + // Add extra paranoid. This will be statically + // removed at compile time unless a race build. + es.values.Remove(seg) + es.values.Add(r, existing) + seg = es.values.LowerBoundSegment(addr) + } + + // Compute the traversal required & update references. + dots := traverse(obj.Type(), old.Type(), addr, seg.Start()) + wt := es.findType(obj.Type()) + for _, ref := range existing.refs { + ref.Dots = append(ref.Dots, dots...) + ref.Type = wt + } + default: + // There is a non-sensical overlap. + Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj) + } + + // Compute the new reference, record and return it. + ref.Root = wire.Uint(existing.id) + ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr) + ref.Type = es.findType(obj.Type()) + existing.refs = append(existing.refs, ref) + return + } + + // The only remaining case is a pointer value that doesn't overlap with + // any registered addresses. Create a new entry for it, and start + // tracking the first reference we just created. + oes := &objectEncodeState{ + id: es.nextID(), + obj: obj, + } + if !raceEnabled { + es.values.AddWithoutMerging(r, oes) } else { - // Push back the map itself; when maps are encoded from the - // top-level, forceMap will be equal to true. - es.pending.PushBack(queuedObject{id: id, obj: obj, path: es.recoverable.copy()}) + // Merges should never happen. This is just enabled extra + // sanity checks because the Merge function below will panic. + es.values.Add(r, oes) + } + es.pending.PushBack(oes) + es.deferred.PushBack(oes) + ref.Root = wire.Uint(oes.id) + oes.refs = append(oes.refs, ref) +} + +// traverse searches for a target object within a root object, where the target +// object is a struct field or array element within root, with potentially +// multiple intervening types. traverse returns the set of field or element +// traversals required to reach the target. +// +// Note that for efficiency, traverse returns the dots in the reverse order. +// That is, the first traversal required will be the last element of the list. +// +// Precondition: The target object must lie completely within the range defined +// by [rootAddr, rootAddr + sizeof(rootType)]. +func traverse(rootType, targetType reflect.Type, rootAddr, targetAddr uintptr) []wire.Dot { + // Recursion base case: the types actually match. + if targetType == rootType && targetAddr == rootAddr { + return nil } - return id + switch rootType.Kind() { + case reflect.Struct: + offset := targetAddr - rootAddr + for i := rootType.NumField(); i > 0; i-- { + field := rootType.Field(i - 1) + // The first field from the end with an offset that is + // smaller than or equal to our address offset is where + // the target is located. Traverse from there. + if field.Offset <= offset { + dots := traverse(field.Type, targetType, rootAddr+field.Offset, targetAddr) + fieldName := wire.FieldName(field.Name) + return append(dots, &fieldName) + } + } + // Should never happen; the target should be reachable. + Failf("no field in root type %v contains target type %v", rootType, targetType) + + case reflect.Array: + // Since arrays have homogenous types, all elements have the + // same size and we can compute where the target lives. This + // does not matter for the purpose of typing, but matters for + // the purpose of computing the address of the given index. + elemSize := int(rootType.Elem().Size()) + n := int(targetAddr-rootAddr) / elemSize // Relies on integer division rounding down. + if rootType.Len() < n { + Failf("traversal target of type %v @%x is beyond the end of the array type %v @%x with %v elements", + targetType, targetAddr, rootType, rootAddr, rootType.Len()) + } + dots := traverse(rootType.Elem(), targetType, rootAddr+uintptr(n*elemSize), targetAddr) + return append(dots, wire.Index(n)) + + default: + // For any other type, there's no possibility of aliasing so if + // the types didn't match earlier then we have an addresss + // collision which shouldn't be possible at this point. + Failf("traverse failed for root type %v and target type %v", rootType, targetType) + } + panic("unreachable") } // encodeMap encodes a map. -func (es *encodeState) encodeMap(obj reflect.Value) *pb.Map { - var ( - keys []*pb.Object - values []*pb.Object - ) +func (es *encodeState) encodeMap(obj reflect.Value, dest *wire.Object) { + if obj.IsNil() { + // Because there is a difference between a nil map and an empty + // map, we need to not decode in the case of a truly nil map. + *dest = wire.Nil{} + return + } + l := obj.Len() + m := &wire.Map{ + Keys: make([]wire.Object, l), + Values: make([]wire.Object, l), + } + *dest = m for i, k := range obj.MapKeys() { v := obj.MapIndex(k) - kp := es.encodeObject(k, false, ".(key %d)", i) - vp := es.encodeObject(v, false, "[%#v]", k.Interface()) - keys = append(keys, kp) - values = append(values, vp) + // Map keys must be encoded using the full value because the + // type will be omitted after the first key. + es.encodeObject(k, encodeAsValue, &m.Keys[i]) + es.encodeObject(v, encodeAsValue, &m.Values[i]) } - return &pb.Map{Keys: keys, Values: values} +} + +// objectEncoder is for encoding structs. +type objectEncoder struct { + // es is encodeState. + es *encodeState + + // encoded is the encoded struct. + encoded *wire.Struct +} + +// save is called by the public methods on Sink. +func (oe *objectEncoder) save(slot int, obj reflect.Value) { + fieldValue := oe.encoded.Field(slot) + oe.es.encodeObject(obj, encodeDefault, fieldValue) } // encodeStruct encodes a composite object. -func (es *encodeState) encodeStruct(obj reflect.Value) *pb.Struct { - // Invoke the save. - m := Map{newInternalMap(es, nil, nil)} - defer internalMapPool.Put(m.internalMap) +func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) { + // Ensure that the obj is addressable. There are two cases when it is + // not. First, is when this is dispatched via SaveValue. Second, when + // this is a map key as a struct. Either way, we need to make a copy to + // obtain an addressable value. if !obj.CanAddr() { - // Force it to a * type of the above; this involves a copy. localObj := reflect.New(obj.Type()) localObj.Elem().Set(obj) obj = localObj.Elem() } - fns, ok := registeredTypes.lookupFns(obj.Addr().Type()) - if ok { - // Invoke the provided saver. - fns.invokeSave(obj.Addr(), m) - } else if obj.NumField() == 0 { - // Allow unregistered anonymous, empty structs. - return &pb.Struct{} - } else { - // Propagate an error. - panic(fmt.Errorf("unregistered type %T", obj.Interface())) - } - - // Sort the underlying slice, and check for duplicates. This is done - // once instead of on each add, because performing this sort once is - // far more efficient. - if len(m.data) > 1 { - sort.Slice(m.data, func(i, j int) bool { - return m.data[i].name < m.data[j].name - }) - for i := range m.data { - if i > 0 && m.data[i-1].name == m.data[i].name { - panic(fmt.Errorf("duplicate name %s", m.data[i].name)) - } + + // Prepare the value. + s := &wire.Struct{} + *dest = s + + // Look the type up in the database. + te, ok := es.types.Lookup(obj.Type()) + if te == nil { + if obj.NumField() == 0 { + // Allow unregistered anonymous, empty structs. This + // will just return success without ever invoking the + // passed function. This uses the immutable EmptyStruct + // variable to prevent an allocation in this case. + // + // Note that this mechanism does *not* work for + // interfaces in general. So you can't dispatch + // non-registered empty structs via interfaces because + // then they can't be restored. + s.Alloc(0) + return } + // We need a SaverLoader for struct types. + Failf("struct %T does not implement SaverLoader", obj.Interface()) } - - // Encode the resulting fields. - fields := make([]*pb.Field, 0, len(m.data)) - for _, e := range m.data { - fields = append(fields, &pb.Field{ - Name: e.name, - Value: e.object, - }) + if !ok { + // Queue the type to be serialized. + es.pendingTypes = append(es.pendingTypes, te.Type) } - // Return the encoded object. - return &pb.Struct{Fields: fields} + // Invoke the provided saver. + s.TypeID = wire.TypeID(te.ID) + s.Alloc(len(te.Fields)) + oe := objectEncoder{ + es: es, + encoded: s, + } + es.stats.start(te.ID) + defer es.stats.done() + if sl, ok := obj.Addr().Interface().(SaverLoader); ok { + // Note: may be a registered empty struct which does not + // implement the saver/loader interfaces. + sl.StateSave(Sink{internal: oe}) + } } // encodeArray encodes an array. -func (es *encodeState) encodeArray(obj reflect.Value) *pb.Array { - var ( - contents []*pb.Object - ) - for i := 0; i < obj.Len(); i++ { - entry := es.encodeObject(obj.Index(i), false, "[%d]", i) - contents = append(contents, entry) - } - return &pb.Array{Contents: contents} +func (es *encodeState) encodeArray(obj reflect.Value, dest *wire.Object) { + l := obj.Len() + a := &wire.Array{ + Contents: make([]wire.Object, l), + } + *dest = a + for i := 0; i < l; i++ { + // We need to encode the full value because arrays are encoded + // using the type information from only the first element. + es.encodeObject(obj.Index(i), encodeAsValue, &a.Contents[i]) + } +} + +// findType recursively finds type information. +func (es *encodeState) findType(typ reflect.Type) wire.TypeSpec { + // First: check if this is a proper type. It's possible for pointers, + // slices, arrays, maps, etc to all have some different type. + te, ok := es.types.Lookup(typ) + if te != nil { + if !ok { + // See encodeStruct. + es.pendingTypes = append(es.pendingTypes, te.Type) + } + return wire.TypeID(te.ID) + } + + switch typ.Kind() { + case reflect.Ptr: + return &wire.TypeSpecPointer{ + Type: es.findType(typ.Elem()), + } + case reflect.Slice: + return &wire.TypeSpecSlice{ + Type: es.findType(typ.Elem()), + } + case reflect.Array: + return &wire.TypeSpecArray{ + Count: wire.Uint(typ.Len()), + Type: es.findType(typ.Elem()), + } + case reflect.Map: + return &wire.TypeSpecMap{ + Key: es.findType(typ.Key()), + Value: es.findType(typ.Elem()), + } + default: + // After potentially chasing many pointers, the + // ultimate type of the object is not known. + Failf("type %q is not known", typ) + } + panic("unreachable") } // encodeInterface encodes an interface. -// -// Precondition: the value is not nil. -func (es *encodeState) encodeInterface(obj reflect.Value) *pb.Interface { - // Check for the nil interface. - obj = reflect.ValueOf(obj.Interface()) +func (es *encodeState) encodeInterface(obj reflect.Value, dest *wire.Object) { + // Dereference the object. + obj = obj.Elem() if !obj.IsValid() { - return &pb.Interface{ - Type: "", // left alone in decode. - Value: &pb.Object{Value: &pb.Object_RefValue{0}}, + // Special case: the nil object. + *dest = &wire.Interface{ + Type: wire.TypeSpecNil{}, + Value: wire.Nil{}, } + return } - // We have an interface value here. How do we save that? We - // resolve the underlying type and save it as a dispatchable. - typName, ok := registeredTypes.lookupName(obj.Type()) - if !ok { - panic(fmt.Errorf("type %s is not registered", obj.Type())) + + // Encode underlying object. + i := &wire.Interface{ + Type: es.findType(obj.Type()), } + *dest = i + es.encodeObject(obj, encodeAsValue, &i.Value) +} - // Encode the object again. - return &pb.Interface{ - Type: typName, - Value: es.encodeObject(obj, false, ".(%s)", typName), +// isPrimitive returns true if this is a primitive object, or a composite +// object composed entirely of primitives. +func isPrimitiveZero(typ reflect.Type) bool { + switch typ.Kind() { + case reflect.Ptr: + // Pointers are always treated as primitive types because we + // won't encode directly from here. Returning true here won't + // prevent the object from being encoded correctly. + return true + case reflect.Bool: + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return true + case reflect.Float32, reflect.Float64: + return true + case reflect.Complex64, reflect.Complex128: + return true + case reflect.String: + return true + case reflect.Slice: + // The slice itself a primitive, but not necessarily the array + // that points to. This is similar to a pointer. + return true + case reflect.Array: + // We cannot treat an array as a primitive, because it may be + // composed of structures or other things with side-effects. + return isPrimitiveZero(typ.Elem()) + case reflect.Interface: + // Since we now that this type is the zero type, the interface + // value must be zero. Therefore this is primitive. + return true + case reflect.Struct: + return false + case reflect.Map: + // The isPrimitiveZero function is called only on zero-types to + // see if it's safe to serialize. Since a zero map has no + // elements, it is safe to treat as a primitive. + return true + default: + Failf("unknown type %q", typ.Name()) } + panic("unreachable") } -// encodeObject encodes an object. -// -// If mapAsValue is true, then a map will be encoded directly. -func (es *encodeState) encodeObject(obj reflect.Value, mapAsValue bool, format string, param interface{}) (object *pb.Object) { - es.push(false, format, param) - es.stats.Add(obj) - es.stats.Start(obj) +// encodeStrategy is the strategy used for encodeObject. +type encodeStrategy int +const ( + // encodeDefault means types are encoded normally as references. + encodeDefault encodeStrategy = iota + + // encodeAsValue means that types will never take short-circuited and + // will always be encoded as a normal value. + encodeAsValue + + // encodeMapAsValue means that even maps will be fully encoded. + encodeMapAsValue +) + +// encodeObject encodes an object. +func (es *encodeState) encodeObject(obj reflect.Value, how encodeStrategy, dest *wire.Object) { + if how == encodeDefault && isPrimitiveZero(obj.Type()) && obj.IsZero() { + *dest = wire.Nil{} + return + } switch obj.Kind() { + case reflect.Ptr: // Fast path: first. + r := new(wire.Ref) + *dest = r + if obj.IsNil() { + // May be in an array or elsewhere such that a value is + // required. So we encode as a reference to the zero + // object, which does not exist. Note that this has to + // be handled correctly in the decode path as well. + return + } + es.resolve(obj, r) case reflect.Bool: - object = &pb.Object{Value: &pb.Object_BoolValue{obj.Bool()}} + *dest = wire.Bool(obj.Bool()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - object = &pb.Object{Value: &pb.Object_Int64Value{obj.Int()}} + *dest = wire.Int(obj.Int()) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - object = &pb.Object{Value: &pb.Object_Uint64Value{obj.Uint()}} - case reflect.Float32, reflect.Float64: - object = &pb.Object{Value: &pb.Object_DoubleValue{obj.Float()}} + *dest = wire.Uint(obj.Uint()) + case reflect.Float32: + *dest = wire.Float32(obj.Float()) + case reflect.Float64: + *dest = wire.Float64(obj.Float()) + case reflect.Complex64: + c := wire.Complex64(obj.Complex()) + *dest = &c // Needs alloc. + case reflect.Complex128: + c := wire.Complex128(obj.Complex()) + *dest = &c // Needs alloc. + case reflect.String: + s := wire.String(obj.String()) + *dest = &s // Needs alloc. case reflect.Array: - switch obj.Type().Elem().Kind() { - case reflect.Uint8: - object = &pb.Object{Value: &pb.Object_ByteArrayValue{pbSlice(obj).Interface().([]byte)}} - case reflect.Uint16: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details. - s := pbSlice(obj).Interface().([]uint16) - t := make([]uint32, len(s)) - for i := range s { - t[i] = uint32(s[i]) - } - object = &pb.Object{Value: &pb.Object_Uint16ArrayValue{&pb.Uint16S{Values: t}}} - case reflect.Uint32: - object = &pb.Object{Value: &pb.Object_Uint32ArrayValue{&pb.Uint32S{Values: pbSlice(obj).Interface().([]uint32)}}} - case reflect.Uint64: - object = &pb.Object{Value: &pb.Object_Uint64ArrayValue{&pb.Uint64S{Values: pbSlice(obj).Interface().([]uint64)}}} - case reflect.Uintptr: - object = &pb.Object{Value: &pb.Object_UintptrArrayValue{&pb.Uintptrs{Values: pbSlice(obj).Interface().([]uint64)}}} - case reflect.Int8: - object = &pb.Object{Value: &pb.Object_Int8ArrayValue{&pb.Int8S{Values: pbSlice(obj).Interface().([]byte)}}} - case reflect.Int16: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details. - s := pbSlice(obj).Interface().([]int16) - t := make([]int32, len(s)) - for i := range s { - t[i] = int32(s[i]) - } - object = &pb.Object{Value: &pb.Object_Int16ArrayValue{&pb.Int16S{Values: t}}} - case reflect.Int32: - object = &pb.Object{Value: &pb.Object_Int32ArrayValue{&pb.Int32S{Values: pbSlice(obj).Interface().([]int32)}}} - case reflect.Int64: - object = &pb.Object{Value: &pb.Object_Int64ArrayValue{&pb.Int64S{Values: pbSlice(obj).Interface().([]int64)}}} - case reflect.Bool: - object = &pb.Object{Value: &pb.Object_BoolArrayValue{&pb.Bools{Values: pbSlice(obj).Interface().([]bool)}}} - case reflect.Float32: - object = &pb.Object{Value: &pb.Object_Float32ArrayValue{&pb.Float32S{Values: pbSlice(obj).Interface().([]float32)}}} - case reflect.Float64: - object = &pb.Object{Value: &pb.Object_Float64ArrayValue{&pb.Float64S{Values: pbSlice(obj).Interface().([]float64)}}} - default: - object = &pb.Object{Value: &pb.Object_ArrayValue{es.encodeArray(obj)}} - } + es.encodeArray(obj, dest) case reflect.Slice: - if obj.IsNil() || obj.Cap() == 0 { - // Handled specially in decode; store as nil value. - object = &pb.Object{Value: &pb.Object_RefValue{0}} - } else { - // Serialize a slice as the array plus length and capacity. - object = &pb.Object{Value: &pb.Object_SliceValue{&pb.Slice{ - Capacity: uint32(obj.Cap()), - Length: uint32(obj.Len()), - RefValue: es.register(arrayFromSlice(obj)), - }}} + s := &wire.Slice{ + Capacity: wire.Uint(obj.Cap()), + Length: wire.Uint(obj.Len()), } - case reflect.String: - object = &pb.Object{Value: &pb.Object_StringValue{[]byte(obj.String())}} - case reflect.Ptr: + *dest = s + // Note that we do need to provide a wire.Slice type here as + // how is not encodeDefault. If this were the case, then it + // would have been caught by the IsZero check above and we + // would have just used wire.Nil{}. if obj.IsNil() { - // Handled specially in decode; store as a nil value. - object = &pb.Object{Value: &pb.Object_RefValue{0}} - } else { - es.push(true /* dereference */, "", nil) - object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}} - es.pop() + return } + // Slices need pointer resolution. + es.resolve(arrayFromSlice(obj), &s.Ref) case reflect.Interface: - // We don't check for IsNil here, as we want to encode type - // information. The case of the empty interface (no type, no - // value) is handled by encodeInteface. - object = &pb.Object{Value: &pb.Object_InterfaceValue{es.encodeInterface(obj)}} + es.encodeInterface(obj, dest) case reflect.Struct: - object = &pb.Object{Value: &pb.Object_StructValue{es.encodeStruct(obj)}} + es.encodeStruct(obj, dest) case reflect.Map: - if obj.IsNil() { - // Handled specially in decode; store as a nil value. - object = &pb.Object{Value: &pb.Object_RefValue{0}} - } else if mapAsValue { - // Encode the map directly. - object = &pb.Object{Value: &pb.Object_MapValue{es.encodeMap(obj)}} - } else { - // Encode a reference to the map. - // - // Remove the map object count here to avoid double - // counting, as this object will be counted again when - // it gets processed later. We do not add a reference - // count as the reference is artificial. - es.stats.Remove(obj) - object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}} + if how == encodeMapAsValue { + es.encodeMap(obj, dest) + return } + r := new(wire.Ref) + *dest = r + es.resolve(obj, r) default: - panic(fmt.Errorf("unknown primitive %#v", obj.Interface())) + Failf("unknown object %#v", obj.Interface()) + panic("unreachable") } - - es.stats.Done() - es.pop() - return } -// Serialize serializes the object state. -// -// This function may panic and should be run in safely(). -func (es *encodeState) Serialize(obj reflect.Value) { - es.register(obj.Addr()) - - // Pop off the list until we're done. - for es.pending.Len() > 0 { - e := es.pending.Front() - - // Extract the queued object. - qo := e.Value.(queuedObject) - es.stats.Start(qo.obj) +// Save serializes the object graph rooted at obj. +func (es *encodeState) Save(obj reflect.Value) { + es.stats.init() + defer es.stats.fini(func(id typeID) string { + return es.pendingTypes[id-1].Name + }) + + // Resolve the first object, which should queue a pile of additional + // objects on the pending list. All queued objects should be fully + // resolved, and we should be able to serialize after this call. + var root wire.Ref + es.resolve(obj.Addr(), &root) + + // Encode the graph. + var oes *objectEncodeState + if err := safely(func() { + for oes = es.deferred.Front(); oes != nil; oes = es.deferred.Front() { + // Remove and encode the object. Note that as a result + // of this encoding, the object may be enqueued on the + // deferred list yet again. That's expected, and why it + // is removed first. + es.deferred.Remove(oes) + es.encodeObject(oes.obj, oes.how, &oes.encoded) + } + }); err != nil { + // Include the object in the error message. + Failf("encoding error at object %#v: %w", oes.obj.Interface(), err) + } - es.pending.Remove(e) + // Check that items are pending. + if es.pending.Front() == nil { + Failf("pending is empty?") + } - es.from = &qo.path - o := es.encodeObject(qo.obj, true, "", nil) + // Write the header with the number of objects. Note that there is no + // way that es.lastID could conflict with objectID, which would + // indicate that an impossibly large encoding. + if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil { + Failf("error writing header: %w", err) + } - // Emit to our output stream. - if err := es.writeObject(qo.id, o); err != nil { - panic(err) + // Serialize all pending types and pending objects. Note that we don't + // bother removing from this list as we walk it because that just + // wastes time. It will not change after this point. + var id objectID + if err := safely(func() { + for _, wt := range es.pendingTypes { + // Encode the type. + wire.Save(es.w, &wt) } + for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() { + id++ // First object is 1. + if oes.id != id { + Failf("expected id %d, got %d", id, oes.id) + } - // Mark as done. - es.done.PushBack(e) - es.stats.Done() + // Marshall the object. + wire.Save(es.w, oes.encoded) + } + }); err != nil { + // Include the object and the error. + Failf("error serializing object %#v: %w", oes.encoded, err) } - // Write a zero-length terminal at the end; this is a sanity check - // applied at decode time as well (see decode.go). - if err := WriteHeader(es.w, 0, false); err != nil { - panic(err) + // Check what we wrote. + if id != es.lastID { + Failf("expected %d objects, wrote %d", es.lastID, id) } } +// objectFlag indicates that the length is a # of objects, rather than a raw +// byte length. When this is set on a length header in the stream, it may be +// decoded appropriately. +const objectFlag uint64 = 1 << 63 + // WriteHeader writes a header. // // Each object written to the statefile should be prefixed with a header. In // order to generate statefiles that play nicely with debugging tools, raw // writes should be prefixed with a header with object set to false and the // appropriate length. This will allow tools to skip these regions. -func WriteHeader(w io.Writer, length uint64, object bool) error { - // The lowest-order bit encodes whether this is a valid object. This is - // a purely internal convention, but allows the object flag to be - // returned from ReadHeader. - length = length << 1 +func WriteHeader(w wire.Writer, length uint64, object bool) error { + // Sanity check the length. + if length&objectFlag != 0 { + Failf("impossibly huge length: %d", length) + } if object { - length |= 0x1 + length |= objectFlag } // Write a header. - var hdr [32]byte - encodedLen := binary.PutUvarint(hdr[:], length) - for done := 0; done < encodedLen; { - n, err := w.Write(hdr[done:encodedLen]) - done += n - if n == 0 && err != nil { - return err - } - } - - return nil + return safely(func() { + wire.SaveUint(w, length) + }) } -// writeObject writes an object to the stream. -func (es *encodeState) writeObject(id uint64, obj *pb.Object) error { - // Marshal the proto. - buf, err := proto.Marshal(obj) - if err != nil { - return err - } +// pendingMapper is for the pending list. +type pendingMapper struct{} - // Write the object header. - if err := WriteHeader(es.w, uint64(len(buf)), true); err != nil { - return err - } +func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry } - // Write the object. - for done := 0; done < len(buf); { - n, err := es.w.Write(buf[done:]) - done += n - if n == 0 && err != nil { - return err - } - } +// deferredMapper is for the deferred list. +type deferredMapper struct{} - return nil -} +func (deferredMapper) linkerFor(oes *objectEncodeState) *deferredEntry { return &oes.deferredEntry } // addrSetFunctions is used by addrSet. type addrSetFunctions struct{} @@ -458,13 +818,24 @@ func (addrSetFunctions) MaxKey() uintptr { return ^uintptr(0) } -func (addrSetFunctions) ClearValue(val *reflect.Value) { +func (addrSetFunctions) ClearValue(val **objectEncodeState) { + *val = nil } -func (addrSetFunctions) Merge(_ addrRange, val1 reflect.Value, _ addrRange, val2 reflect.Value) (reflect.Value, bool) { - return val1, val1 == val2 +func (addrSetFunctions) Merge(r1 addrRange, val1 *objectEncodeState, r2 addrRange, val2 *objectEncodeState) (*objectEncodeState, bool) { + if val1.obj == val2.obj { + // This, should never happen. It would indicate that the same + // object exists in two non-contiguous address ranges. Note + // that this assertion can only be triggered if the race + // detector is enabled. + Failf("unexpected merge in addrSet @ %v and %v: %#v and %#v", r1, r2, val1.obj, val2.obj) + } + // Reject the merge. + return val1, false } -func (addrSetFunctions) Split(_ addrRange, val reflect.Value, _ uintptr) (reflect.Value, reflect.Value) { - return val, val +func (addrSetFunctions) Split(r addrRange, val *objectEncodeState, _ uintptr) (*objectEncodeState, *objectEncodeState) { + // A split should never happen: we don't remove ranges. + Failf("unexpected split in addrSet @ %v: %#v", r, val.obj) + panic("unreachable") } diff --git a/pkg/state/encode_unsafe.go b/pkg/state/encode_unsafe.go index 457e6dbb7..e0dad83b4 100644 --- a/pkg/state/encode_unsafe.go +++ b/pkg/state/encode_unsafe.go @@ -31,51 +31,3 @@ func arrayFromSlice(obj reflect.Value) reflect.Value { reflect.ArrayOf(obj.Cap(), obj.Type().Elem()), unsafe.Pointer(obj.Pointer())) } - -// pbSlice returns a protobuf-supported slice of the array and erase the -// original element type (which could be a defined type or non-supported type). -func pbSlice(obj reflect.Value) reflect.Value { - var typ reflect.Type - switch obj.Type().Elem().Kind() { - case reflect.Uint8: - typ = reflect.TypeOf(byte(0)) - case reflect.Uint16: - typ = reflect.TypeOf(uint16(0)) - case reflect.Uint32: - typ = reflect.TypeOf(uint32(0)) - case reflect.Uint64: - typ = reflect.TypeOf(uint64(0)) - case reflect.Uintptr: - typ = reflect.TypeOf(uint64(0)) - case reflect.Int8: - typ = reflect.TypeOf(byte(0)) - case reflect.Int16: - typ = reflect.TypeOf(int16(0)) - case reflect.Int32: - typ = reflect.TypeOf(int32(0)) - case reflect.Int64: - typ = reflect.TypeOf(int64(0)) - case reflect.Bool: - typ = reflect.TypeOf(bool(false)) - case reflect.Float32: - typ = reflect.TypeOf(float32(0)) - case reflect.Float64: - typ = reflect.TypeOf(float64(0)) - default: - panic("slice element is not of basic value type") - } - return reflect.NewAt( - reflect.ArrayOf(obj.Len(), typ), - unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()), - ).Elem().Slice(0, obj.Len()) -} - -func castSlice(obj reflect.Value, elemTyp reflect.Type) reflect.Value { - if obj.Type().Elem().Size() != elemTyp.Size() { - panic("cannot cast slice into other element type of different size") - } - return reflect.NewAt( - reflect.ArrayOf(obj.Len(), elemTyp), - unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()), - ).Elem() -} diff --git a/pkg/state/map.go b/pkg/state/map.go deleted file mode 100644 index 4f3ebb0da..000000000 --- a/pkg/state/map.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package state - -import ( - "context" - "fmt" - "reflect" - "sort" - "sync" - - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" -) - -// entry is a single map entry. -type entry struct { - name string - object *pb.Object -} - -// internalMap is the internal Map state. -// -// These are recycled via a pool to avoid churn. -type internalMap struct { - // es is encodeState. - es *encodeState - - // ds is decodeState. - ds *decodeState - - // os is current object being decoded. - // - // This will always be nil during encode. - os *objectState - - // data stores the encoded values. - data []entry -} - -var internalMapPool = sync.Pool{ - New: func() interface{} { - return new(internalMap) - }, -} - -// newInternalMap returns a cached map. -func newInternalMap(es *encodeState, ds *decodeState, os *objectState) *internalMap { - m := internalMapPool.Get().(*internalMap) - m.es = es - m.ds = ds - m.os = os - if m.data != nil { - m.data = m.data[:0] - } - return m -} - -// Map is a generic state container. -// -// This is the object passed to Save and Load in order to store their state. -// -// Detailed documentation is available in individual methods. -type Map struct { - *internalMap -} - -// Save adds the given object to the map. -// -// You should pass always pointers to the object you are saving. For example: -// -// type X struct { -// A int -// B *int -// } -// -// func (x *X) Save(m Map) { -// m.Save("A", &x.A) -// m.Save("B", &x.B) -// } -// -// func (x *X) Load(m Map) { -// m.Load("A", &x.A) -// m.Load("B", &x.B) -// } -func (m Map) Save(name string, objPtr interface{}) { - m.save(name, reflect.ValueOf(objPtr).Elem(), ".%s") -} - -// SaveValue adds the given object value to the map. -// -// This should be used for values where pointers are not available, or casts -// are required during Save/Load. -// -// For example, if we want to cast external package type P.Foo to int64: -// -// type X struct { -// A P.Foo -// } -// -// func (x *X) Save(m Map) { -// m.SaveValue("A", int64(x.A)) -// } -// -// func (x *X) Load(m Map) { -// m.LoadValue("A", new(int64), func(x interface{}) { -// x.A = P.Foo(x.(int64)) -// }) -// } -func (m Map) SaveValue(name string, obj interface{}) { - m.save(name, reflect.ValueOf(obj), ".(value %s)") -} - -// save is helper for the above. It takes the name of value to save the field -// to, the field object (obj), and a format string that specifies how the -// field's saving logic is dispatched from the struct (normal, value, etc.). The -// format string should expect one string parameter, which is the name of the -// field. -func (m Map) save(name string, obj reflect.Value, format string) { - if m.es == nil { - // Not currently encoding. - m.Failf("no encode state for %q", name) - } - - // Attempt the encode. - // - // These are sorted at the end, after all objects are added and will be - // sorted and checked for duplicates (see encodeStruct). - m.data = append(m.data, entry{ - name: name, - object: m.es.encodeObject(obj, false, format, name), - }) -} - -// Load loads the given object from the map. -// -// See Save for an example. -func (m Map) Load(name string, objPtr interface{}) { - m.load(name, reflect.ValueOf(objPtr), false, nil, ".%s") -} - -// LoadWait loads the given objects from the map, and marks it as requiring all -// AfterLoad executions to complete prior to running this object's AfterLoad. -// -// See Save for an example. -func (m Map) LoadWait(name string, objPtr interface{}) { - m.load(name, reflect.ValueOf(objPtr), true, nil, ".(wait %s)") -} - -// LoadValue loads the given object value from the map. -// -// See SaveValue for an example. -func (m Map) LoadValue(name string, objPtr interface{}, fn func(interface{})) { - o := reflect.ValueOf(objPtr) - m.load(name, o, true, func() { fn(o.Elem().Interface()) }, ".(value %s)") -} - -// load is helper for the above. It takes the name of value to load the field -// from, the target field pointer (objPtr), whether load completion of the -// struct depends on the field's load completion (wait), the load completion -// logic (fn), and a format string that specifies how the field's loading logic -// is dispatched from the struct (normal, wait, value, etc.). The format string -// should expect one string parameter, which is the name of the field. -func (m Map) load(name string, objPtr reflect.Value, wait bool, fn func(), format string) { - if m.ds == nil { - // Not currently decoding. - m.Failf("no decode state for %q", name) - } - - // Find the object. - // - // These are sorted up front (and should appear in the state file - // sorted as well), so we can do a binary search here to ensure that - // large structs don't behave badly. - i := sort.Search(len(m.data), func(i int) bool { - return m.data[i].name >= name - }) - if i >= len(m.data) || m.data[i].name != name { - // There is no data for this name? - m.Failf("no data found for %q", name) - } - - // Perform the decode. - m.ds.decodeObject(m.os, objPtr.Elem(), m.data[i].object, format, name) - if wait { - // Mark this individual object a blocker. - m.ds.waitObject(m.os, m.data[i].object, fn) - } -} - -// Failf fails the save or restore with the provided message. Processing will -// stop after calling Failf, as the state package uses a panic & recover -// mechanism for state errors. You should defer any cleanup required. -func (m Map) Failf(format string, args ...interface{}) { - panic(fmt.Errorf(format, args...)) -} - -// AfterLoad schedules a function execution when all objects have been allocated -// and their automated loading and customized load logic have been executed. fn -// will not be executed until all of current object's dependencies' AfterLoad() -// logic, if exist, have been executed. -func (m Map) AfterLoad(fn func()) { - if m.ds == nil { - // Not currently decoding. - m.Failf("not decoding") - } - - // Queue the local callback; this will execute when all of the above - // data dependencies have been cleared. - m.os.callbacks = append(m.os.callbacks, fn) -} - -// Context returns the current context object. -func (m Map) Context() context.Context { - if m.es != nil { - return m.es.ctx - } else if m.ds != nil { - return m.ds.ctx - } - return context.Background() // No context. -} diff --git a/pkg/state/object.proto b/pkg/state/object.proto deleted file mode 100644 index 5ebcfb151..000000000 --- a/pkg/state/object.proto +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor.state.statefile; - -// Slice is a slice value. -message Slice { - uint32 length = 1; - uint32 capacity = 2; - uint64 ref_value = 3; -} - -// Array is an array value. -message Array { - repeated Object contents = 1; -} - -// Map is a map value. -message Map { - repeated Object keys = 1; - repeated Object values = 2; -} - -// Interface is an interface value. -message Interface { - string type = 1; - Object value = 2; -} - -// Struct is a basic composite value. -message Struct { - repeated Field fields = 1; -} - -// Field encodes a single field. -message Field { - string name = 1; - Object value = 2; -} - -// Uint16s encodes an uint16 array. To be used inside oneof structure. -message Uint16s { - // There is no 16-bit type in protobuf so we use variable length 32-bit here. - repeated uint32 values = 1; -} - -// Uint32s encodes an uint32 array. To be used inside oneof structure. -message Uint32s { - repeated fixed32 values = 1; -} - -// Uint64s encodes an uint64 array. To be used inside oneof structure. -message Uint64s { - repeated fixed64 values = 1; -} - -// Uintptrs encodes an uintptr array. To be used inside oneof structure. -message Uintptrs { - repeated fixed64 values = 1; -} - -// Int8s encodes an int8 array. To be used inside oneof structure. -message Int8s { - bytes values = 1; -} - -// Int16s encodes an int16 array. To be used inside oneof structure. -message Int16s { - // There is no 16-bit type in protobuf so we use variable length 32-bit here. - repeated int32 values = 1; -} - -// Int32s encodes an int32 array. To be used inside oneof structure. -message Int32s { - repeated sfixed32 values = 1; -} - -// Int64s encodes an int64 array. To be used inside oneof structure. -message Int64s { - repeated sfixed64 values = 1; -} - -// Bools encodes a boolean array. To be used inside oneof structure. -message Bools { - repeated bool values = 1; -} - -// Float64s encodes a float64 array. To be used inside oneof structure. -message Float64s { - repeated double values = 1; -} - -// Float32s encodes a float32 array. To be used inside oneof structure. -message Float32s { - repeated float values = 1; -} - -// Object are primitive encodings. -// -// Note that ref_value references an Object.id, below. -message Object { - oneof value { - bool bool_value = 1; - bytes string_value = 2; - int64 int64_value = 3; - uint64 uint64_value = 4; - double double_value = 5; - uint64 ref_value = 6; - Slice slice_value = 7; - Array array_value = 8; - Interface interface_value = 9; - Struct struct_value = 10; - Map map_value = 11; - bytes byte_array_value = 12; - Uint16s uint16_array_value = 13; - Uint32s uint32_array_value = 14; - Uint64s uint64_array_value = 15; - Uintptrs uintptr_array_value = 16; - Int8s int8_array_value = 17; - Int16s int16_array_value = 18; - Int32s int32_array_value = 19; - Int64s int64_array_value = 20; - Bools bool_array_value = 21; - Float64s float64_array_value = 22; - Float32s float32_array_value = 23; - } -} diff --git a/pkg/state/pretty/BUILD b/pkg/state/pretty/BUILD new file mode 100644 index 000000000..d053802f7 --- /dev/null +++ b/pkg/state/pretty/BUILD @@ -0,0 +1,13 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "pretty", + srcs = ["pretty.go"], + visibility = ["//:sandbox"], + deps = [ + "//pkg/state", + "//pkg/state/wire", + ], +) diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go new file mode 100644 index 000000000..cf37aaa49 --- /dev/null +++ b/pkg/state/pretty/pretty.go @@ -0,0 +1,273 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package pretty is a pretty-printer for state streams. +package pretty + +import ( + "fmt" + "io" + "io/ioutil" + "reflect" + "strings" + + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/wire" +) + +func formatRef(x *wire.Ref, graph uint64, html bool) string { + baseRef := fmt.Sprintf("g%dr%d", graph, x.Root) + fullRef := baseRef + if len(x.Dots) > 0 { + // See wire.Ref; Type valid if Dots non-zero. + typ, _ := formatType(x.Type, graph, html) + var buf strings.Builder + buf.WriteString("(*") + buf.WriteString(typ) + buf.WriteString(")(") + buf.WriteString(baseRef) + for _, component := range x.Dots { + switch v := component.(type) { + case *wire.FieldName: + buf.WriteString(".") + buf.WriteString(string(*v)) + case wire.Index: + buf.WriteString(fmt.Sprintf("[%d]", v)) + default: + panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component))) + } + } + buf.WriteString(")") + fullRef = buf.String() + } + if html { + return fmt.Sprintf("<a href=\"#%s\">%s</a>", baseRef, fullRef) + } + return fullRef +} + +func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) { + switch x := t.(type) { + case wire.TypeID: + base := fmt.Sprintf("g%dt%d", graph, x) + if html { + return fmt.Sprintf("<a href=\"#%s\">%s</a>", base, base), true + } + return fmt.Sprintf("%s", base), true + case wire.TypeSpecNil: + return "", false // Only nil type. + case *wire.TypeSpecPointer: + element, _ := formatType(x.Type, graph, html) + return fmt.Sprintf("(*%s)", element), true + case *wire.TypeSpecArray: + element, _ := formatType(x.Type, graph, html) + return fmt.Sprintf("[%d](%s)", x.Count, element), true + case *wire.TypeSpecSlice: + element, _ := formatType(x.Type, graph, html) + return fmt.Sprintf("([]%s)", element), true + case *wire.TypeSpecMap: + key, _ := formatType(x.Key, graph, html) + value, _ := formatType(x.Value, graph, html) + return fmt.Sprintf("(map[%s]%s)", key, value), true + default: + panic(fmt.Sprintf("unreachable: unknown type %T", t)) + } +} + +// format formats a single object, for pretty-printing. It also returns whether +// the value is a non-zero value. +func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bool) { + switch x := encoded.(type) { + case wire.Nil: + return "nil", false + case *wire.String: + return fmt.Sprintf("%q", *x), *x != "" + case *wire.Complex64: + return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0 + case *wire.Complex128: + return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0 + case *wire.Ref: + return formatRef(x, graph, html), x.Root != 0 + case *wire.Type: + tabs := "\n" + strings.Repeat("\t", depth) + items := make([]string, 0, len(x.Fields)+2) + items = append(items, fmt.Sprintf("type %s {", x.Name)) + for i := 0; i < len(x.Fields); i++ { + items = append(items, fmt.Sprintf("\t%d: %s,", i, x.Fields[i])) + } + items = append(items, "}") + return strings.Join(items, tabs), true // No zero value. + case *wire.Slice: + return fmt.Sprintf("%s{len:%d,cap:%d}", formatRef(&x.Ref, graph, html), x.Length, x.Capacity), x.Capacity != 0 + case *wire.Array: + if len(x.Contents) == 0 { + return "[]", false + } + items := make([]string, 0, len(x.Contents)+2) + zeros := make([]string, 0) // used to eliminate zero entries. + items = append(items, "[") + tabs := "\n" + strings.Repeat("\t", depth) + for i := 0; i < len(x.Contents); i++ { + item, ok := format(graph, depth+1, x.Contents[i], html) + if !ok { + zeros = append(zeros, fmt.Sprintf("\t%s,", item)) + continue + } + if len(zeros) > 0 { + items = append(items, zeros...) + zeros = nil + } + items = append(items, fmt.Sprintf("\t%s,", item)) + } + if len(zeros) > 0 { + items = append(items, fmt.Sprintf("\t... (%d zeros),", len(zeros))) + } + items = append(items, "]") + return strings.Join(items, tabs), len(zeros) < len(x.Contents) + case *wire.Struct: + typ, _ := formatType(x.TypeID, graph, html) + if x.Fields() == 0 { + return fmt.Sprintf("struct[%s]{}", typ), false + } + items := make([]string, 0, 2) + items = append(items, fmt.Sprintf("struct[%s]{", typ)) + tabs := "\n" + strings.Repeat("\t", depth) + allZero := true + for i := 0; i < x.Fields(); i++ { + element, ok := format(graph, depth+1, *x.Field(i), html) + allZero = allZero && !ok + items = append(items, fmt.Sprintf("\t%d: %s,", i, element)) + i++ + } + items = append(items, "}") + return strings.Join(items, tabs), !allZero + case *wire.Map: + if len(x.Keys) == 0 { + return "map{}", false + } + items := make([]string, 0, len(x.Keys)+2) + items = append(items, "map{") + tabs := "\n" + strings.Repeat("\t", depth) + for i := 0; i < len(x.Keys); i++ { + key, _ := format(graph, depth+1, x.Keys[i], html) + value, _ := format(graph, depth+1, x.Values[i], html) + items = append(items, fmt.Sprintf("\t%s: %s,", key, value)) + } + items = append(items, "}") + return strings.Join(items, tabs), true + case *wire.Interface: + typ, typOk := formatType(x.Type, graph, html) + element, elementOk := format(graph, depth+1, x.Value, html) + return fmt.Sprintf("interface[%s]{%s}", typ, element), typOk || elementOk + default: + // Must be a primitive; use reflection. + return fmt.Sprintf("%v", encoded), true + } +} + +// printStream is the basic print implementation. +func printStream(w io.Writer, r wire.Reader, html bool) (err error) { + // current graph ID. + var graph uint64 + + if html { + fmt.Fprintf(w, "<pre>") + defer fmt.Fprintf(w, "</pre>") + } + + defer func() { + if r := recover(); r != nil { + if rErr, ok := r.(error); ok { + err = rErr // Override return. + return + } + panic(r) // Propagate. + } + }() + + for { + // Find the first object to begin generation. + length, object, err := state.ReadHeader(r) + if err == io.EOF { + // Nothing else to do. + break + } else if err != nil { + return err + } + if !object { + graph++ // Increment the graph. + if length > 0 { + fmt.Fprintf(w, "(%d bytes non-object data)\n", length) + io.Copy(ioutil.Discard, &io.LimitedReader{ + R: r, + N: int64(length), + }) + } + continue + } + + // Read & unmarshal the object. + // + // Note that this loop must match the general structure of the + // loop in decode.go. But we don't register type information, + // etc. and just print the raw structures. + var ( + oid uint64 = 1 + tid uint64 = 1 + ) + for oid <= length { + // Unmarshal the object. + encoded := wire.Load(r) + + // Is this a type? + if _, ok := encoded.(*wire.Type); ok { + str, _ := format(graph, 0, encoded, html) + tag := fmt.Sprintf("g%dt%d", graph, tid) + if html { + // See below. + tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) + } + if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil { + return err + } + tid++ + continue + } + + // Format the node. + str, _ := format(graph, 0, encoded, html) + tag := fmt.Sprintf("g%dr%d", graph, oid) + if html { + // Create a little tag with an anchor next to it for linking. + tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) + } + if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil { + return err + } + oid++ + } + } + + return nil +} + +// PrintText reads the stream from r and prints text to w. +func PrintText(w io.Writer, r wire.Reader) error { + return printStream(w, r, false /* html */) +} + +// PrintHTML reads the stream from r and prints html to w. +func PrintHTML(w io.Writer, r wire.Reader) error { + return printStream(w, r, true /* html */) +} diff --git a/pkg/state/printer.go b/pkg/state/printer.go deleted file mode 100644 index 3ce18242f..000000000 --- a/pkg/state/printer.go +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package state - -import ( - "fmt" - "io" - "io/ioutil" - "reflect" - "strings" - - "github.com/golang/protobuf/proto" - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" -) - -// format formats a single object, for pretty-printing. It also returns whether -// the value is a non-zero value. -func format(graph uint64, depth int, object *pb.Object, html bool) (string, bool) { - switch x := object.GetValue().(type) { - case *pb.Object_BoolValue: - return fmt.Sprintf("%t", x.BoolValue), x.BoolValue != false - case *pb.Object_StringValue: - return fmt.Sprintf("\"%s\"", string(x.StringValue)), len(x.StringValue) != 0 - case *pb.Object_Int64Value: - return fmt.Sprintf("%d", x.Int64Value), x.Int64Value != 0 - case *pb.Object_Uint64Value: - return fmt.Sprintf("%du", x.Uint64Value), x.Uint64Value != 0 - case *pb.Object_DoubleValue: - return fmt.Sprintf("%f", x.DoubleValue), x.DoubleValue != 0.0 - case *pb.Object_RefValue: - if x.RefValue == 0 { - return "nil", false - } - ref := fmt.Sprintf("g%dr%d", graph, x.RefValue) - if html { - ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref) - } - return ref, true - case *pb.Object_SliceValue: - if x.SliceValue.RefValue == 0 { - return "nil", false - } - ref := fmt.Sprintf("g%dr%d", graph, x.SliceValue.RefValue) - if html { - ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref) - } - return fmt.Sprintf("%s[:%d:%d]", ref, x.SliceValue.Length, x.SliceValue.Capacity), true - case *pb.Object_ArrayValue: - if len(x.ArrayValue.Contents) == 0 { - return "[]", false - } - items := make([]string, 0, len(x.ArrayValue.Contents)+2) - zeros := make([]string, 0) // used to eliminate zero entries. - items = append(items, "[") - tabs := "\n" + strings.Repeat("\t", depth) - for i := 0; i < len(x.ArrayValue.Contents); i++ { - item, ok := format(graph, depth+1, x.ArrayValue.Contents[i], html) - if ok { - if len(zeros) > 0 { - items = append(items, zeros...) - zeros = nil - } - items = append(items, fmt.Sprintf("\t%s,", item)) - } else { - zeros = append(zeros, fmt.Sprintf("\t%s,", item)) - } - } - if len(zeros) > 0 { - items = append(items, fmt.Sprintf("\t... (%d zeros),", len(zeros))) - } - items = append(items, "]") - return strings.Join(items, tabs), len(zeros) < len(x.ArrayValue.Contents) - case *pb.Object_StructValue: - if len(x.StructValue.Fields) == 0 { - return "struct{}", false - } - items := make([]string, 0, len(x.StructValue.Fields)+2) - items = append(items, "struct{") - tabs := "\n" + strings.Repeat("\t", depth) - allZero := true - for _, field := range x.StructValue.Fields { - element, ok := format(graph, depth+1, field.Value, html) - allZero = allZero && !ok - items = append(items, fmt.Sprintf("\t%s: %s,", field.Name, element)) - } - items = append(items, "}") - return strings.Join(items, tabs), !allZero - case *pb.Object_MapValue: - if len(x.MapValue.Keys) == 0 { - return "map{}", false - } - items := make([]string, 0, len(x.MapValue.Keys)+2) - items = append(items, "map{") - tabs := "\n" + strings.Repeat("\t", depth) - for i := 0; i < len(x.MapValue.Keys); i++ { - key, _ := format(graph, depth+1, x.MapValue.Keys[i], html) - value, _ := format(graph, depth+1, x.MapValue.Values[i], html) - items = append(items, fmt.Sprintf("\t%s: %s,", key, value)) - } - items = append(items, "}") - return strings.Join(items, tabs), true - case *pb.Object_InterfaceValue: - if x.InterfaceValue.Type == "" { - return "interface(nil){}", false - } - element, _ := format(graph, depth+1, x.InterfaceValue.Value, html) - return fmt.Sprintf("interface(\"%s\"){%s}", x.InterfaceValue.Type, element), true - case *pb.Object_ByteArrayValue: - return printArray(reflect.ValueOf(x.ByteArrayValue)) - case *pb.Object_Uint16ArrayValue: - return printArray(reflect.ValueOf(x.Uint16ArrayValue.Values)) - case *pb.Object_Uint32ArrayValue: - return printArray(reflect.ValueOf(x.Uint32ArrayValue.Values)) - case *pb.Object_Uint64ArrayValue: - return printArray(reflect.ValueOf(x.Uint64ArrayValue.Values)) - case *pb.Object_UintptrArrayValue: - return printArray(castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0)))) - case *pb.Object_Int8ArrayValue: - return printArray(castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0)))) - case *pb.Object_Int16ArrayValue: - return printArray(reflect.ValueOf(x.Int16ArrayValue.Values)) - case *pb.Object_Int32ArrayValue: - return printArray(reflect.ValueOf(x.Int32ArrayValue.Values)) - case *pb.Object_Int64ArrayValue: - return printArray(reflect.ValueOf(x.Int64ArrayValue.Values)) - case *pb.Object_BoolArrayValue: - return printArray(reflect.ValueOf(x.BoolArrayValue.Values)) - case *pb.Object_Float64ArrayValue: - return printArray(reflect.ValueOf(x.Float64ArrayValue.Values)) - case *pb.Object_Float32ArrayValue: - return printArray(reflect.ValueOf(x.Float32ArrayValue.Values)) - } - - // Should not happen, but tolerate. - return fmt.Sprintf("(unknown proto type: %T)", object.GetValue()), true -} - -// PrettyPrint reads the state stream from r, and pretty prints to w. -func PrettyPrint(w io.Writer, r io.Reader, html bool) error { - var ( - // current graph ID. - graph uint64 - - // current object ID. - id uint64 - ) - - if html { - fmt.Fprintf(w, "<pre>") - defer fmt.Fprintf(w, "</pre>") - } - - for { - // Find the first object to begin generation. - length, object, err := ReadHeader(r) - if err == io.EOF { - // Nothing else to do. - break - } else if err != nil { - return err - } - if !object { - // Increment the graph number & reset the ID. - graph++ - id = 0 - if length > 0 { - fmt.Fprintf(w, "(%d bytes non-object data)\n", length) - io.Copy(ioutil.Discard, &io.LimitedReader{ - R: r, - N: int64(length), - }) - } - continue - } - - // Read & unmarshal the object. - buf := make([]byte, length) - for done := 0; done < len(buf); { - n, err := r.Read(buf[done:]) - done += n - if n == 0 && err != nil { - return err - } - } - obj := new(pb.Object) - if err := proto.Unmarshal(buf, obj); err != nil { - return err - } - - id++ // First object must be one. - str, _ := format(graph, 0, obj, html) - tag := fmt.Sprintf("g%dr%d", graph, id) - if html { - tag = fmt.Sprintf("<a name=%s>%s</a>", tag, tag) - } - if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil { - return err - } - } - - return nil -} - -func printArray(s reflect.Value) (string, bool) { - zero := reflect.Zero(s.Type().Elem()).Interface() - z := "0" - switch s.Type().Elem().Kind() { - case reflect.Bool: - z = "false" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - case reflect.Float32, reflect.Float64: - default: - return fmt.Sprintf("unexpected non-primitive type array: %#v", s.Interface()), true - } - - zeros := 0 - items := make([]string, 0, s.Len()) - for i := 0; i <= s.Len(); i++ { - if i < s.Len() && reflect.DeepEqual(s.Index(i).Interface(), zero) { - zeros++ - continue - } - if zeros > 0 { - if zeros <= 4 { - for ; zeros > 0; zeros-- { - items = append(items, z) - } - } else { - items = append(items, fmt.Sprintf("(%d %ss)", zeros, z)) - zeros = 0 - } - } - if i < s.Len() { - items = append(items, fmt.Sprintf("%v", s.Index(i).Interface())) - } - } - return "[" + strings.Join(items, ",") + "]", zeros < s.Len() -} diff --git a/pkg/state/state.go b/pkg/state/state.go index 03ae2dbb0..acb629969 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -31,210 +31,226 @@ // Uint64 default // Float32 default // Float64 default -// Complex64 custom -// Complex128 custom +// Complex64 default +// Complex128 default // Array default // Chan custom // Func custom -// Interface custom -// Map default (*) +// Interface default +// Map default // Ptr default // Slice default // String default -// Struct custom +// Struct custom (*) Unless zero-sized. // UnsafePointer custom // -// (*) Maps are treated as value types by this package, even if they are -// pointers internally. If you want to save two independent references -// to the same map value, you must explicitly use a pointer to a map. +// See README.md for an overview of how encoding and decoding works. package state import ( "context" "fmt" - "io" "reflect" "runtime" - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" + "gvisor.dev/gvisor/pkg/state/wire" ) +// objectID is a unique identifier assigned to each object to be serialized. +// Each instance of an object is considered separately, i.e. if there are two +// objects of the same type in the object graph being serialized, they'll be +// assigned unique objectIDs. +type objectID uint32 + +// typeID is the identifier for a type. Types are serialized and tracked +// alongside objects in order to avoid the overhead of encoding field names in +// all objects. +type typeID uint32 + // ErrState is returned when an error is encountered during encode/decode. type ErrState struct { // err is the underlying error. err error - // path is the visit path from root to the current object. - path string - // trace is the stack trace. trace string } // Error returns a sensible description of the state error. func (e *ErrState) Error() string { - return fmt.Sprintf("%v:\nstate path: %s\n%s", e.err, e.path, e.trace) + return fmt.Sprintf("%v:\n%s", e.err, e.trace) } -// UnwrapErrState returns the underlying error in ErrState. -// -// If err is not *ErrState, err is returned directly. -func UnwrapErrState(err error) error { - if e, ok := err.(*ErrState); ok { - return e.err - } - return err +// Unwrap implements standard unwrapping. +func (e *ErrState) Unwrap() error { + return e.err } // Save saves the given object state. -func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error { +func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) { // Create the encoding state. - es := &encodeState{ - ctx: ctx, - idsByObject: make(map[uintptr]uint64), - w: w, - stats: stats, + es := encodeState{ + ctx: ctx, + w: w, + types: makeTypeEncodeDatabase(), + zeroValues: make(map[reflect.Type]*objectEncodeState), } // Perform the encoding. - return es.safely(func() { - es.Serialize(reflect.ValueOf(rootPtr).Elem()) + err := safely(func() { + es.Save(reflect.ValueOf(rootPtr).Elem()) }) + return es.stats, err } // Load loads a checkpoint. -func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error { +func Load(ctx context.Context, r wire.Reader, rootPtr interface{}) (Stats, error) { // Create the decoding state. - ds := &decodeState{ - ctx: ctx, - objectsByID: make(map[uint64]*objectState), - deferred: make(map[uint64]*pb.Object), - r: r, - stats: stats, + ds := decodeState{ + ctx: ctx, + r: r, + types: makeTypeDecodeDatabase(), + deferred: make(map[objectID]wire.Object), } // Attempt our decode. - return ds.safely(func() { - ds.Deserialize(reflect.ValueOf(rootPtr).Elem()) + err := safely(func() { + ds.Load(reflect.ValueOf(rootPtr).Elem()) }) + return ds.stats, err } -// Fns are the state dispatch functions. -type Fns struct { - // Save is a function like Save(concreteType, Map). - Save interface{} - - // Load is a function like Load(concreteType, Map). - Load interface{} +// Sink is used for Type.StateSave. +type Sink struct { + internal objectEncoder } -// Save executes the save function. -func (fns *Fns) invokeSave(obj reflect.Value, m Map) { - reflect.ValueOf(fns.Save).Call([]reflect.Value{obj, reflect.ValueOf(m)}) +// Save adds the given object to the map. +// +// You should pass always pointers to the object you are saving. For example: +// +// type X struct { +// A int +// B *int +// } +// +// func (x *X) StateTypeInfo(m Sink) state.TypeInfo { +// return state.TypeInfo{ +// Name: "pkg.X", +// Fields: []string{ +// "A", +// "B", +// }, +// } +// } +// +// func (x *X) StateSave(m Sink) { +// m.Save(0, &x.A) // Field is A. +// m.Save(1, &x.B) // Field is B. +// } +// +// func (x *X) StateLoad(m Source) { +// m.Load(0, &x.A) // Field is A. +// m.Load(1, &x.B) // Field is B. +// } +func (s Sink) Save(slot int, objPtr interface{}) { + s.internal.save(slot, reflect.ValueOf(objPtr).Elem()) } -// Load executes the load function. -func (fns *Fns) invokeLoad(obj reflect.Value, m Map) { - reflect.ValueOf(fns.Load).Call([]reflect.Value{obj, reflect.ValueOf(m)}) +// SaveValue adds the given object value to the map. +// +// This should be used for values where pointers are not available, or casts +// are required during Save/Load. +// +// For example, if we want to cast external package type P.Foo to int64: +// +// func (x *X) StateSave(m Sink) { +// m.SaveValue(0, "A", int64(x.A)) +// } +// +// func (x *X) StateLoad(m Source) { +// m.LoadValue(0, new(int64), func(x interface{}) { +// x.A = P.Foo(x.(int64)) +// }) +// } +func (s Sink) SaveValue(slot int, obj interface{}) { + s.internal.save(slot, reflect.ValueOf(obj)) } -// validateStateFn ensures types are correct. -func validateStateFn(fn interface{}, typ reflect.Type) bool { - fnTyp := reflect.TypeOf(fn) - if fnTyp.Kind() != reflect.Func { - return false - } - if fnTyp.NumIn() != 2 { - return false - } - if fnTyp.NumOut() != 0 { - return false - } - if fnTyp.In(0) != typ { - return false - } - if fnTyp.In(1) != reflect.TypeOf(Map{}) { - return false - } - return true +// Context returns the context object provided at save time. +func (s Sink) Context() context.Context { + return s.internal.es.ctx } -// Validate validates all state functions. -func (fns *Fns) Validate(typ reflect.Type) bool { - return validateStateFn(fns.Save, typ) && validateStateFn(fns.Load, typ) +// Type is an interface that must be implemented by Struct objects. This allows +// these objects to be serialized while minimizing runtime reflection required. +// +// All these methods can be automatically generated by the go_statify tool. +type Type interface { + // StateTypeName returns the type's name. + // + // This is used for matching type information during encoding and + // decoding, as well as dynamic interface dispatch. This should be + // globally unique. + StateTypeName() string + + // StateFields returns information about the type. + // + // Fields is the set of fields for the object. Calls to Sink.Save and + // Source.Load must be made in-order with respect to these fields. + // + // This will be called at most once per serialization. + StateFields() []string } -type typeDatabase struct { - // nameToType is a forward lookup table. - nameToType map[string]reflect.Type - - // typeToName is the reverse lookup table. - typeToName map[reflect.Type]string +// SaverLoader must be implemented by struct types. +type SaverLoader interface { + // StateSave saves the state of the object to the given Map. + StateSave(Sink) - // typeToFns is the function lookup table. - typeToFns map[reflect.Type]Fns + // StateLoad loads the state of the object. + StateLoad(Source) } -// registeredTypes is a database used for SaveInterface and LoadInterface. -var registeredTypes = typeDatabase{ - nameToType: make(map[string]reflect.Type), - typeToName: make(map[reflect.Type]string), - typeToFns: make(map[reflect.Type]Fns), +// Source is used for Type.StateLoad. +type Source struct { + internal objectDecoder } -// register registers a type under the given name. This will generally be -// called via init() methods, and therefore uses panic to propagate errors. -func (t *typeDatabase) register(name string, typ reflect.Type, fns Fns) { - // We can't allow name collisions. - if ot, ok := t.nameToType[name]; ok { - panic(fmt.Sprintf("type %q can't use name %q, already in use by type %q", typ.Name(), name, ot.Name())) - } - - // Or multiple registrations. - if on, ok := t.typeToName[typ]; ok { - panic(fmt.Sprintf("type %q can't be registered as %q, already registered as %q", typ.Name(), name, on)) - } - - t.nameToType[name] = typ - t.typeToName[typ] = name - t.typeToFns[typ] = fns +// Load loads the given object passed as a pointer.. +// +// See Sink.Save for an example. +func (s Source) Load(slot int, objPtr interface{}) { + s.internal.load(slot, reflect.ValueOf(objPtr), false, nil) } -// lookupType finds a type given a name. -func (t *typeDatabase) lookupType(name string) (reflect.Type, bool) { - typ, ok := t.nameToType[name] - return typ, ok +// LoadWait loads the given objects from the map, and marks it as requiring all +// AfterLoad executions to complete prior to running this object's AfterLoad. +// +// See Sink.Save for an example. +func (s Source) LoadWait(slot int, objPtr interface{}) { + s.internal.load(slot, reflect.ValueOf(objPtr), true, nil) } -// lookupName finds a name given a type. -func (t *typeDatabase) lookupName(typ reflect.Type) (string, bool) { - name, ok := t.typeToName[typ] - return name, ok +// LoadValue loads the given object value from the map. +// +// See Sink.SaveValue for an example. +func (s Source) LoadValue(slot int, objPtr interface{}, fn func(interface{})) { + o := reflect.ValueOf(objPtr) + s.internal.load(slot, o, true, func() { fn(o.Elem().Interface()) }) } -// lookupFns finds functions given a type. -func (t *typeDatabase) lookupFns(typ reflect.Type) (Fns, bool) { - fns, ok := t.typeToFns[typ] - return fns, ok +// AfterLoad schedules a function execution when all objects have been +// allocated and their automated loading and customized load logic have been +// executed. fn will not be executed until all of current object's +// dependencies' AfterLoad() logic, if exist, have been executed. +func (s Source) AfterLoad(fn func()) { + s.internal.afterLoad(fn) } -// Register must be called for any interface implementation types that -// implements Loader. -// -// Register should be called either immediately after startup or via init() -// methods. Double registration of either names or types will result in a panic. -// -// No synchronization is provided; this should only be called in init. -// -// Example usage: -// -// state.Register("Foo", (*Foo)(nil), state.Fns{ -// Save: (*Foo).Save, -// Load: (*Foo).Load, -// }) -// -func Register(name string, instance interface{}, fns Fns) { - registeredTypes.register(name, reflect.TypeOf(instance), fns) +// Context returns the context object provided at load time. +func (s Source) Context() context.Context { + return s.internal.ds.ctx } // IsZeroValue checks if the given value is the zero value. @@ -244,72 +260,14 @@ func IsZeroValue(val interface{}) bool { return val == nil || reflect.ValueOf(val).Elem().IsZero() } -// step captures one encoding / decoding step. On each step, there is up to one -// choice made, which is captured by non-nil param. We intentionally do not -// eagerly create the final path string, as that will only be needed upon panic. -type step struct { - // dereference indicate if the current object is obtained by - // dereferencing a pointer. - dereference bool - - // format is the formatting string that takes param below, if - // non-nil. For example, in array indexing case, we have "[%d]". - format string - - // param stores the choice made at the current encoding / decoding step. - // For eaxmple, in array indexing case, param stores the index. When no - // choice is made, e.g. dereference, param should be nil. - param interface{} -} - -// recoverable is the state encoding / decoding panic recovery facility. It is -// also used to store encoding / decoding steps as well as the reference to the -// original queued object from which the current object is dispatched. The -// complete encoding / decoding path is synthesised from the steps in all queued -// objects leading to the current object. -type recoverable struct { - from *recoverable - steps []step +// Failf is a wrapper around panic that should be used to generate errors that +// can be caught during saving and loading. +func Failf(fmtStr string, v ...interface{}) { + panic(fmt.Errorf(fmtStr, v...)) } -// push enters a new context level. -func (sr *recoverable) push(dereference bool, format string, param interface{}) { - sr.steps = append(sr.steps, step{dereference, format, param}) -} - -// pop exits the current context level. -func (sr *recoverable) pop() { - if len(sr.steps) <= 1 { - return - } - sr.steps = sr.steps[:len(sr.steps)-1] -} - -// path returns the complete encoding / decoding path from root. This is only -// called upon panic. -func (sr *recoverable) path() string { - if sr.from == nil { - return "root" - } - p := sr.from.path() - for _, s := range sr.steps { - if s.dereference { - p = fmt.Sprintf("*(%s)", p) - } - if s.param == nil { - p += s.format - } else { - p += fmt.Sprintf(s.format, s.param) - } - } - return p -} - -func (sr *recoverable) copy() recoverable { - return recoverable{from: sr.from, steps: append([]step(nil), sr.steps...)} -} - -// safely executes the given function, catching a panic and unpacking as an error. +// safely executes the given function, catching a panic and unpacking as an +// error. // // The error flow through the state package uses panic and recover. There are // two important reasons for this: @@ -323,9 +281,15 @@ func (sr *recoverable) copy() recoverable { // method doesn't add a lot of value. If there are specific error conditions // that you'd like to handle, you should add appropriate functionality to // objects themselves prior to calling Save() and Load(). -func (sr *recoverable) safely(fn func()) (err error) { +func safely(fn func()) (err error) { defer func() { if r := recover(); r != nil { + if es, ok := r.(*ErrState); ok { + err = es // Propagate. + return + } + + // Build a new state error. es := new(ErrState) if e, ok := r.(error); ok { es.err = e @@ -333,8 +297,6 @@ func (sr *recoverable) safely(fn func()) (err error) { es.err = fmt.Errorf("%v", r) } - es.path = sr.path() - // Make a stack. We don't know how big it will be ahead // of time, but want to make sure we get the whole // thing. So we just do a stupid brute force approach. diff --git a/pkg/state/state_norace.go b/pkg/state/state_norace.go new file mode 100644 index 000000000..4281aed6d --- /dev/null +++ b/pkg/state/state_norace.go @@ -0,0 +1,19 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !race + +package state + +var raceEnabled = false diff --git a/pkg/state/state_race.go b/pkg/state/state_race.go new file mode 100644 index 000000000..8232981ce --- /dev/null +++ b/pkg/state/state_race.go @@ -0,0 +1,19 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build race + +package state + +var raceEnabled = true diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go deleted file mode 100644 index d7221e9e8..000000000 --- a/pkg/state/state_test.go +++ /dev/null @@ -1,721 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package state - -import ( - "bytes" - "context" - "io/ioutil" - "math" - "reflect" - "testing" -) - -// TestCase is used to define a single success/failure testcase of -// serialization of a set of objects. -type TestCase struct { - // Name is the name of the test case. - Name string - - // Objects is the list of values to serialize. - Objects []interface{} - - // Fail is whether the test case is supposed to fail or not. - Fail bool -} - -// runTest runs all testcases. -func runTest(t *testing.T, tests []TestCase) { - for _, test := range tests { - t.Logf("TEST %s:", test.Name) - for i, root := range test.Objects { - t.Logf(" case#%d: %#v", i, root) - - // Save the passed object. - saveBuffer := &bytes.Buffer{} - saveObjectPtr := reflect.New(reflect.TypeOf(root)) - saveObjectPtr.Elem().Set(reflect.ValueOf(root)) - if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { - t.Errorf(" FAIL: Save failed unexpectedly: %v", err) - continue - } else if err != nil { - t.Logf(" PASS: Save failed as expected: %v", err) - continue - } - - // Load a new copy of the object. - loadObjectPtr := reflect.New(reflect.TypeOf(root)) - if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { - t.Errorf(" FAIL: Load failed unexpectedly: %v", err) - continue - } else if err != nil { - t.Logf(" PASS: Load failed as expected: %v", err) - continue - } - - // Compare the values. - loadedValue := loadObjectPtr.Elem().Interface() - if eq := reflect.DeepEqual(root, loadedValue); !eq && !test.Fail { - t.Errorf(" FAIL: Objects differs; got %#v", loadedValue) - continue - } else if !eq { - t.Logf(" PASS: Object different as expected.") - continue - } - - // Everything went okay. Is that good? - if test.Fail { - t.Errorf(" FAIL: Unexpected success.") - } else { - t.Logf(" PASS: Success.") - } - } - } -} - -// dumbStruct is a struct which does not implement the loader/saver interface. -// We expect that serialization of this struct will fail. -type dumbStruct struct { - A int - B int -} - -// smartStruct is a struct which does implement the loader/saver interface. -// We expect that serialization of this struct will succeed. -type smartStruct struct { - A int - B int -} - -func (s *smartStruct) save(m Map) { - m.Save("A", &s.A) - m.Save("B", &s.B) -} - -func (s *smartStruct) load(m Map) { - m.Load("A", &s.A) - m.Load("B", &s.B) -} - -// valueLoadStruct uses a value load. -type valueLoadStruct struct { - v int -} - -func (v *valueLoadStruct) save(m Map) { - m.SaveValue("v", v.v) -} - -func (v *valueLoadStruct) load(m Map) { - m.LoadValue("v", new(int), func(value interface{}) { - v.v = value.(int) - }) -} - -// afterLoadStruct has an AfterLoad function. -type afterLoadStruct struct { - v int -} - -func (a *afterLoadStruct) save(m Map) { -} - -func (a *afterLoadStruct) load(m Map) { - m.AfterLoad(func() { - a.v++ - }) -} - -// genericContainer is a generic dispatcher. -type genericContainer struct { - v interface{} -} - -func (g *genericContainer) save(m Map) { - m.Save("v", &g.v) -} - -func (g *genericContainer) load(m Map) { - m.Load("v", &g.v) -} - -// sliceContainer is a generic slice. -type sliceContainer struct { - v []interface{} -} - -func (s *sliceContainer) save(m Map) { - m.Save("v", &s.v) -} - -func (s *sliceContainer) load(m Map) { - m.Load("v", &s.v) -} - -// mapContainer is a generic map. -type mapContainer struct { - v map[int]interface{} -} - -func (mc *mapContainer) save(m Map) { - m.Save("v", &mc.v) -} - -func (mc *mapContainer) load(m Map) { - // Some of the test cases below assume legacy behavior wherein maps - // will automatically inherit dependencies. - m.LoadWait("v", &mc.v) -} - -// dumbMap is a map which does not implement the loader/saver interface. -// Serialization of this map will default to the standard encode/decode logic. -type dumbMap map[string]int - -// pointerStruct contains various pointers, shared and non-shared, and pointers -// to pointers. We expect that serialization will respect the structure. -type pointerStruct struct { - A *int - B *int - C *int - D *int - - AA **int - BB **int -} - -func (p *pointerStruct) save(m Map) { - m.Save("A", &p.A) - m.Save("B", &p.B) - m.Save("C", &p.C) - m.Save("D", &p.D) - m.Save("AA", &p.AA) - m.Save("BB", &p.BB) -} - -func (p *pointerStruct) load(m Map) { - m.Load("A", &p.A) - m.Load("B", &p.B) - m.Load("C", &p.C) - m.Load("D", &p.D) - m.Load("AA", &p.AA) - m.Load("BB", &p.BB) -} - -// testInterface is a trivial interface example. -type testInterface interface { - Foo() -} - -// testImpl is a trivial implementation of testInterface. -type testImpl struct { -} - -// Foo satisfies testInterface. -func (t *testImpl) Foo() { -} - -// testImpl is trivially serializable. -func (t *testImpl) save(m Map) { -} - -// testImpl is trivially serializable. -func (t *testImpl) load(m Map) { -} - -// testI demonstrates interface dispatching. -type testI struct { - I testInterface -} - -func (t *testI) save(m Map) { - m.Save("I", &t.I) -} - -func (t *testI) load(m Map) { - m.Load("I", &t.I) -} - -// cycleStruct is used to implement basic cycles. -type cycleStruct struct { - c *cycleStruct -} - -func (c *cycleStruct) save(m Map) { - m.Save("c", &c.c) -} - -func (c *cycleStruct) load(m Map) { - m.Load("c", &c.c) -} - -// badCycleStruct actually has deadlocking dependencies. -// -// This should pass if b.b = {nil|b} and fail otherwise. -type badCycleStruct struct { - b *badCycleStruct -} - -func (b *badCycleStruct) save(m Map) { - m.Save("b", &b.b) -} - -func (b *badCycleStruct) load(m Map) { - m.LoadWait("b", &b.b) - m.AfterLoad(func() { - // This is not executable, since AfterLoad requires that the - // object and all dependencies are complete. This should cause - // a deadlock error during load. - }) -} - -// emptyStructPointer points to an empty struct. -type emptyStructPointer struct { - nothing *struct{} -} - -func (e *emptyStructPointer) save(m Map) { - m.Save("nothing", &e.nothing) -} - -func (e *emptyStructPointer) load(m Map) { - m.Load("nothing", &e.nothing) -} - -// truncateInteger truncates an integer. -type truncateInteger struct { - v int64 - v2 int32 -} - -func (t *truncateInteger) save(m Map) { - t.v2 = int32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateInteger) load(m Map) { - m.Load("v", &t.v2) - t.v = int64(t.v2) -} - -// truncateUnsignedInteger truncates an unsigned integer. -type truncateUnsignedInteger struct { - v uint64 - v2 uint32 -} - -func (t *truncateUnsignedInteger) save(m Map) { - t.v2 = uint32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateUnsignedInteger) load(m Map) { - m.Load("v", &t.v2) - t.v = uint64(t.v2) -} - -// truncateFloat truncates a floating point number. -type truncateFloat struct { - v float64 - v2 float32 -} - -func (t *truncateFloat) save(m Map) { - t.v2 = float32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateFloat) load(m Map) { - m.Load("v", &t.v2) - t.v = float64(t.v2) -} - -func TestTypes(t *testing.T) { - // x and y are basic integers, while xp points to x. - x := 1 - y := 2 - xp := &x - - // cs is a single object cycle. - cs := cycleStruct{nil} - cs.c = &cs - - // cs1 and cs2 are in a two object cycle. - cs1 := cycleStruct{nil} - cs2 := cycleStruct{nil} - cs1.c = &cs2 - cs2.c = &cs1 - - // bs is a single object cycle. - bs := badCycleStruct{nil} - bs.b = &bs - - // bs2 and bs2 are in a deadlocking cycle. - bs1 := badCycleStruct{nil} - bs2 := badCycleStruct{nil} - bs1.b = &bs2 - bs2.b = &bs1 - - // regular nils. - var ( - nilmap dumbMap - nilslice []byte - ) - - // embed points to embedded fields. - embed1 := pointerStruct{} - embed1.AA = &embed1.A - embed2 := pointerStruct{} - embed2.BB = &embed2.B - - // es1 contains two structs pointing to the same empty struct. - es := emptyStructPointer{new(struct{})} - es1 := []emptyStructPointer{es, es} - - tests := []TestCase{ - { - Name: "bool", - Objects: []interface{}{ - true, - false, - }, - }, - { - Name: "integers", - Objects: []interface{}{ - int(0), - int(1), - int(-1), - int8(0), - int8(1), - int8(-1), - int16(0), - int16(1), - int16(-1), - int32(0), - int32(1), - int32(-1), - int64(0), - int64(1), - int64(-1), - }, - }, - { - Name: "unsigned integers", - Objects: []interface{}{ - uint(0), - uint(1), - uint8(0), - uint8(1), - uint16(0), - uint16(1), - uint32(1), - uint64(0), - uint64(1), - }, - }, - { - Name: "strings", - Objects: []interface{}{ - "", - "foo", - "bar", - "\xa0", - }, - }, - { - Name: "slices", - Objects: []interface{}{ - []int{-1, 0, 1}, - []*int{&x, &x, &x}, - []int{1, 2, 3}[0:1], - []int{1, 2, 3}[1:2], - make([]byte, 32), - make([]byte, 32)[:16], - make([]byte, 32)[:16:20], - nilslice, - }, - }, - { - Name: "arrays", - Objects: []interface{}{ - &[1048576]bool{false, true, false, true}, - &[1048576]uint8{0, 1, 2, 3}, - &[1048576]byte{0, 1, 2, 3}, - &[1048576]uint16{0, 1, 2, 3}, - &[1048576]uint{0, 1, 2, 3}, - &[1048576]uint32{0, 1, 2, 3}, - &[1048576]uint64{0, 1, 2, 3}, - &[1048576]uintptr{0, 1, 2, 3}, - &[1048576]int8{0, -1, -2, -3}, - &[1048576]int16{0, -1, -2, -3}, - &[1048576]int32{0, -1, -2, -3}, - &[1048576]int64{0, -1, -2, -3}, - &[1048576]float32{0, 1.1, 2.2, 3.3}, - &[1048576]float64{0, 1.1, 2.2, 3.3}, - }, - }, - { - Name: "pointers", - Objects: []interface{}{ - &pointerStruct{A: &x, B: &x, C: &y, D: &y, AA: &xp, BB: &xp}, - &pointerStruct{}, - }, - }, - { - Name: "empty struct", - Objects: []interface{}{ - struct{}{}, - }, - }, - { - Name: "unenlightened structs", - Objects: []interface{}{ - &dumbStruct{A: 1, B: 2}, - }, - Fail: true, - }, - { - Name: "enlightened structs", - Objects: []interface{}{ - &smartStruct{A: 1, B: 2}, - }, - }, - { - Name: "load-hooks", - Objects: []interface{}{ - &afterLoadStruct{v: 1}, - &valueLoadStruct{v: 1}, - &genericContainer{v: &afterLoadStruct{v: 1}}, - &genericContainer{v: &valueLoadStruct{v: 1}}, - &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}}, - &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}}, - }, - }, - { - Name: "maps", - Objects: []interface{}{ - dumbMap{"a": -1, "b": 0, "c": 1}, - map[smartStruct]int{{}: 0, {A: 1}: 1}, - nilmap, - &mapContainer{v: map[int]interface{}{0: &smartStruct{A: 1}}}, - }, - }, - { - Name: "interfaces", - Objects: []interface{}{ - &testI{&testImpl{}}, - &testI{nil}, - &testI{(*testImpl)(nil)}, - }, - }, - { - Name: "unregistered-interfaces", - Objects: []interface{}{ - &genericContainer{v: afterLoadStruct{v: 1}}, - &genericContainer{v: valueLoadStruct{v: 1}}, - &sliceContainer{v: []interface{}{afterLoadStruct{v: 1}}}, - &sliceContainer{v: []interface{}{valueLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: afterLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: valueLoadStruct{v: 1}}}, - }, - Fail: true, - }, - { - Name: "cycles", - Objects: []interface{}{ - &cs, - &cs1, - &cycleStruct{&cs1}, - &cycleStruct{&cs}, - &badCycleStruct{nil}, - &bs, - }, - }, - { - Name: "deadlock", - Objects: []interface{}{ - &bs1, - }, - Fail: true, - }, - { - Name: "embed", - Objects: []interface{}{ - &embed1, - &embed2, - }, - Fail: true, - }, - { - Name: "empty structs", - Objects: []interface{}{ - new(struct{}), - es, - es1, - }, - }, - { - Name: "truncated okay", - Objects: []interface{}{ - &truncateInteger{v: 1}, - &truncateUnsignedInteger{v: 1}, - &truncateFloat{v: 1.0}, - }, - }, - { - Name: "truncated bad", - Objects: []interface{}{ - &truncateInteger{v: math.MaxInt32 + 1}, - &truncateUnsignedInteger{v: math.MaxUint32 + 1}, - &truncateFloat{v: math.MaxFloat32 * 2}, - }, - Fail: true, - }, - } - - runTest(t, tests) -} - -// benchStruct is used for benchmarking. -type benchStruct struct { - b *benchStruct - - // Dummy data is included to ensure that these objects are large. - // This is to detect possible regression when registering objects. - _ [4096]byte -} - -func (b *benchStruct) save(m Map) { - m.Save("b", &b.b) -} - -func (b *benchStruct) load(m Map) { - m.LoadWait("b", &b.b) - m.AfterLoad(b.afterLoad) -} - -func (b *benchStruct) afterLoad() { - // Do nothing, just force scheduling. -} - -// buildObject builds a benchmark object. -func buildObject(n int) (b *benchStruct) { - for i := 0; i < n; i++ { - b = &benchStruct{b: b} - } - return -} - -func BenchmarkEncoding(b *testing.B) { - b.StopTimer() - bs := buildObject(b.N) - var stats Stats - b.StartTimer() - if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil { - b.Errorf("save failed: %v", err) - } - b.StopTimer() - if b.N > 1000 { - b.Logf("breakdown (n=%d): %s", b.N, &stats) - } -} - -func BenchmarkDecoding(b *testing.B) { - b.StopTimer() - bs := buildObject(b.N) - var newBS benchStruct - buf := &bytes.Buffer{} - if err := Save(context.Background(), buf, bs, nil); err != nil { - b.Errorf("save failed: %v", err) - } - var stats Stats - b.StartTimer() - if err := Load(context.Background(), buf, &newBS, &stats); err != nil { - b.Errorf("load failed: %v", err) - } - b.StopTimer() - if b.N > 1000 { - b.Logf("breakdown (n=%d): %s", b.N, &stats) - } -} - -func init() { - Register("stateTest.smartStruct", (*smartStruct)(nil), Fns{ - Save: (*smartStruct).save, - Load: (*smartStruct).load, - }) - Register("stateTest.afterLoadStruct", (*afterLoadStruct)(nil), Fns{ - Save: (*afterLoadStruct).save, - Load: (*afterLoadStruct).load, - }) - Register("stateTest.valueLoadStruct", (*valueLoadStruct)(nil), Fns{ - Save: (*valueLoadStruct).save, - Load: (*valueLoadStruct).load, - }) - Register("stateTest.genericContainer", (*genericContainer)(nil), Fns{ - Save: (*genericContainer).save, - Load: (*genericContainer).load, - }) - Register("stateTest.sliceContainer", (*sliceContainer)(nil), Fns{ - Save: (*sliceContainer).save, - Load: (*sliceContainer).load, - }) - Register("stateTest.mapContainer", (*mapContainer)(nil), Fns{ - Save: (*mapContainer).save, - Load: (*mapContainer).load, - }) - Register("stateTest.pointerStruct", (*pointerStruct)(nil), Fns{ - Save: (*pointerStruct).save, - Load: (*pointerStruct).load, - }) - Register("stateTest.testImpl", (*testImpl)(nil), Fns{ - Save: (*testImpl).save, - Load: (*testImpl).load, - }) - Register("stateTest.testI", (*testI)(nil), Fns{ - Save: (*testI).save, - Load: (*testI).load, - }) - Register("stateTest.cycleStruct", (*cycleStruct)(nil), Fns{ - Save: (*cycleStruct).save, - Load: (*cycleStruct).load, - }) - Register("stateTest.badCycleStruct", (*badCycleStruct)(nil), Fns{ - Save: (*badCycleStruct).save, - Load: (*badCycleStruct).load, - }) - Register("stateTest.emptyStructPointer", (*emptyStructPointer)(nil), Fns{ - Save: (*emptyStructPointer).save, - Load: (*emptyStructPointer).load, - }) - Register("stateTest.truncateInteger", (*truncateInteger)(nil), Fns{ - Save: (*truncateInteger).save, - Load: (*truncateInteger).load, - }) - Register("stateTest.truncateUnsignedInteger", (*truncateUnsignedInteger)(nil), Fns{ - Save: (*truncateUnsignedInteger).save, - Load: (*truncateUnsignedInteger).load, - }) - Register("stateTest.truncateFloat", (*truncateFloat)(nil), Fns{ - Save: (*truncateFloat).save, - Load: (*truncateFloat).load, - }) - Register("stateTest.benchStruct", (*benchStruct)(nil), Fns{ - Save: (*benchStruct).save, - Load: (*benchStruct).load, - }) -} diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index e7581c09b..d6c89c7e9 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/binary", "//pkg/compressio", + "//pkg/state/wire", ], ) diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go index c0f4c4954..bdfb800fb 100644 --- a/pkg/state/statefile/statefile.go +++ b/pkg/state/statefile/statefile.go @@ -57,6 +57,7 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/compressio" + "gvisor.dev/gvisor/pkg/state/wire" ) // keySize is the AES-256 key length. @@ -83,10 +84,16 @@ var ErrInvalidMetadataLength = fmt.Errorf("metadata length invalid, maximum size // ErrMetadataInvalid is returned if passed metadata is invalid. var ErrMetadataInvalid = fmt.Errorf("metadata invalid, can't start with _") +// WriteCloser is an io.Closer and wire.Writer. +type WriteCloser interface { + wire.Writer + io.Closer +} + // NewWriter returns a state data writer for a statefile. // // Note that the returned WriteCloser must be closed. -func NewWriter(w io.Writer, key []byte, metadata map[string]string) (io.WriteCloser, error) { +func NewWriter(w io.Writer, key []byte, metadata map[string]string) (WriteCloser, error) { if metadata == nil { metadata = make(map[string]string) } @@ -215,7 +222,7 @@ func metadata(r io.Reader, h hash.Hash) (map[string]string, error) { } // NewReader returns a reader for a statefile. -func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) { +func NewReader(r io.Reader, key []byte) (wire.Reader, map[string]string, error) { // Read the metadata with the hash. h := hmac.New(sha256.New, key) metadata, err := metadata(r, h) @@ -224,9 +231,9 @@ func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) { } // Wrap in compression. - rc, err := compressio.NewReader(r, key) + cr, err := compressio.NewReader(r, key) if err != nil { return nil, nil, err } - return rc, metadata, nil + return cr, metadata, nil } diff --git a/pkg/state/stats.go b/pkg/state/stats.go index eb51cda47..eaec664a1 100644 --- a/pkg/state/stats.go +++ b/pkg/state/stats.go @@ -17,7 +17,6 @@ package state import ( "bytes" "fmt" - "reflect" "sort" "time" ) @@ -35,92 +34,81 @@ type statEntry struct { // All exported receivers accept nil. type Stats struct { // byType contains a breakdown of time spent by type. - byType map[reflect.Type]*statEntry + // + // This is indexed *directly* by typeID, including zero. + byType []statEntry // stack contains objects in progress. - stack []reflect.Type + stack []typeID + + // names contains type names. + // + // This is also indexed *directly* by typeID, including zero, which we + // hard-code as "state.default". This is only resolved by calling fini + // on the stats object. + names []string // last is the last start time. last time.Time } -// sample adds the samples to the given object. -func (s *Stats) sample(typ reflect.Type) { - now := time.Now() - s.byType[typ].total += now.Sub(s.last) - s.last = now +// init initializes statistics. +func (s *Stats) init() { + s.last = time.Now() + s.stack = append(s.stack, 0) } -// Add adds a sample count. -func (s *Stats) Add(obj reflect.Value) { - if s == nil { - return - } - if s.byType == nil { - s.byType = make(map[reflect.Type]*statEntry) - } - typ := obj.Type() - entry, ok := s.byType[typ] - if !ok { - entry = new(statEntry) - s.byType[typ] = entry +// fini finalizes statistics. +func (s *Stats) fini(resolve func(id typeID) string) { + s.done() + + // Resolve all type names. + s.names = make([]string, len(s.byType)) + s.names[0] = "state.default" // See above. + for id := typeID(1); int(id) < len(s.names); id++ { + s.names[id] = resolve(id) } - entry.count++ } -// Remove removes a sample count. It should only be called after a previous -// Add(). -func (s *Stats) Remove(obj reflect.Value) { - if s == nil { - return +// sample adds the samples to the given object. +func (s *Stats) sample(id typeID) { + now := time.Now() + if len(s.byType) <= int(id) { + // Allocate all the missing entries in one fell swoop. + s.byType = append(s.byType, make([]statEntry, 1+int(id)-len(s.byType))...) } - typ := obj.Type() - entry := s.byType[typ] - entry.count-- + s.byType[id].total += now.Sub(s.last) + s.last = now } -// Start starts a sample. -func (s *Stats) Start(obj reflect.Value) { - if s == nil { - return - } - if len(s.stack) > 0 { - last := s.stack[len(s.stack)-1] - s.sample(last) - } else { - // First time sample. - s.last = time.Now() - } - s.stack = append(s.stack, obj.Type()) +// start starts a sample. +func (s *Stats) start(id typeID) { + last := s.stack[len(s.stack)-1] + s.sample(last) + s.stack = append(s.stack, id) } -// Done finishes the current sample. -func (s *Stats) Done() { - if s == nil { - return - } +// done finishes the current sample. +func (s *Stats) done() { last := s.stack[len(s.stack)-1] s.sample(last) + s.byType[last].count++ s.stack = s.stack[:len(s.stack)-1] } type sliceEntry struct { - typ reflect.Type + name string entry *statEntry } // String returns a table representation of the stats. func (s *Stats) String() string { - if s == nil || len(s.byType) == 0 { - return "(no data)" - } - // Build a list of stat entries. ss := make([]sliceEntry, 0, len(s.byType)) - for typ, entry := range s.byType { + for id := 0; id < len(s.names); id++ { ss = append(ss, sliceEntry{ - typ: typ, - entry: entry, + name: s.names[id], + entry: &s.byType[id], }) } @@ -136,17 +124,22 @@ func (s *Stats) String() string { total time.Duration ) buf.WriteString("\n") - buf.WriteString(fmt.Sprintf("%12s | %8s | %8s | %s\n", "total", "count", "per", "type")) - buf.WriteString("-------------+----------+----------+-------------\n") + buf.WriteString(fmt.Sprintf("% 16s | % 8s | % 16s | %s\n", "total", "count", "per", "type")) + buf.WriteString("-----------------+----------+------------------+----------------\n") for _, se := range ss { + if se.entry.count == 0 { + // Since we store all types linearly, we are not + // guaranteed that any entry actually has time. + continue + } count += se.entry.count total += se.entry.total per := se.entry.total / time.Duration(se.entry.count) - buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | %s\n", - se.entry.total, se.entry.count, per, se.typ.String())) + buf.WriteString(fmt.Sprintf("% 16s | %8d | % 16s | %s\n", + se.entry.total, se.entry.count, per, se.name)) } - buf.WriteString("-------------+----------+----------+-------------\n") - buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | [all]", + buf.WriteString("-----------------+----------+------------------+----------------\n") + buf.WriteString(fmt.Sprintf("% 16s | % 8d | % 16s | [all]", total, count, total/time.Duration(count))) return string(buf.Bytes()) } diff --git a/pkg/state/tests/BUILD b/pkg/state/tests/BUILD new file mode 100644 index 000000000..9297cafbe --- /dev/null +++ b/pkg/state/tests/BUILD @@ -0,0 +1,43 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "tests", + srcs = [ + "array.go", + "bench.go", + "integer.go", + "load.go", + "map.go", + "register.go", + "struct.go", + "tests.go", + ], + deps = [ + "//pkg/state", + "//pkg/state/pretty", + ], +) + +go_test( + name = "tests_test", + size = "small", + srcs = [ + "array_test.go", + "bench_test.go", + "bool_test.go", + "float_test.go", + "integer_test.go", + "load_test.go", + "map_test.go", + "register_test.go", + "string_test.go", + "struct_test.go", + ], + library = ":tests", + deps = [ + "//pkg/state", + "//pkg/state/wire", + ], +) diff --git a/pkg/state/tests/array.go b/pkg/state/tests/array.go new file mode 100644 index 000000000..0972a80e7 --- /dev/null +++ b/pkg/state/tests/array.go @@ -0,0 +1,35 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type arrayContainer struct { + v [1]interface{} +} + +// +stateify savable +type arrayPtrContainer struct { + v *[1]interface{} +} + +// +stateify savable +type sliceContainer struct { + v []interface{} +} + +// +stateify savable +type slicePtrContainer struct { + v *[]interface{} +} diff --git a/pkg/state/tests/array_test.go b/pkg/state/tests/array_test.go new file mode 100644 index 000000000..a347b2947 --- /dev/null +++ b/pkg/state/tests/array_test.go @@ -0,0 +1,134 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "reflect" + "testing" +) + +var allArrayPrimitives = []interface{}{ + [1]bool{}, + [1]bool{true}, + [2]bool{false, true}, + [1]int{}, + [1]int{1}, + [2]int{0, 1}, + [1]int8{}, + [1]int8{1}, + [2]int8{0, 1}, + [1]int16{}, + [1]int16{1}, + [2]int16{0, 1}, + [1]int32{}, + [1]int32{1}, + [2]int32{0, 1}, + [1]int64{}, + [1]int64{1}, + [2]int64{0, 1}, + [1]uint{}, + [1]uint{1}, + [2]uint{0, 1}, + [1]uintptr{}, + [1]uintptr{1}, + [2]uintptr{0, 1}, + [1]uint8{}, + [1]uint8{1}, + [2]uint8{0, 1}, + [1]uint16{}, + [1]uint16{1}, + [2]uint16{0, 1}, + [1]uint32{}, + [1]uint32{1}, + [2]uint32{0, 1}, + [1]uint64{}, + [1]uint64{1}, + [2]uint64{0, 1}, + [1]string{}, + [1]string{""}, + [1]string{nonEmptyString}, + [2]string{"", nonEmptyString}, +} + +func TestArrayPrimitives(t *testing.T) { + runTestCases(t, false, "plain", flatten(allArrayPrimitives)) + runTestCases(t, false, "pointers", pointersTo(flatten(allArrayPrimitives))) + runTestCases(t, false, "interfaces", interfacesTo(flatten(allArrayPrimitives))) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allArrayPrimitives)))) +} + +func TestSlices(t *testing.T) { + var allSlices = flatten( + filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) { + v := reflect.New(reflect.TypeOf(o)).Elem() + v.Set(reflect.ValueOf(o)) + return v.Slice(0, v.Len()).Interface(), true + }), + filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) { + v := reflect.New(reflect.TypeOf(o)).Elem() + v.Set(reflect.ValueOf(o)) + if v.Len() == 0 { + // Return the pure "nil" value for the slice. + return reflect.New(v.Slice(0, 0).Type()).Elem().Interface(), true + } + return v.Slice(1, v.Len()).Interface(), true + }), + filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) { + v := reflect.New(reflect.TypeOf(o)).Elem() + v.Set(reflect.ValueOf(o)) + if v.Len() == 0 { + // Return the zero-valued slice. + return reflect.MakeSlice(v.Slice(0, 0).Type(), 0, 0).Interface(), true + } + return v.Slice(0, v.Len()-1).Interface(), true + }), + ) + runTestCases(t, false, "plain", allSlices) + runTestCases(t, false, "pointers", pointersTo(allSlices)) + runTestCases(t, false, "interfaces", interfacesTo(allSlices)) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(allSlices))) +} + +func TestArrayContainers(t *testing.T) { + var ( + emptyArray [1]interface{} + fullArray [1]interface{} + ) + fullArray[0] = &emptyArray + runTestCases(t, false, "", []interface{}{ + arrayContainer{v: emptyArray}, + arrayContainer{v: fullArray}, + arrayPtrContainer{v: nil}, + arrayPtrContainer{v: &emptyArray}, + arrayPtrContainer{v: &fullArray}, + }) +} + +func TestSliceContainers(t *testing.T) { + var ( + nilSlice []interface{} + emptySlice = make([]interface{}, 0) + fullSlice = []interface{}{nil} + ) + runTestCases(t, false, "", []interface{}{ + sliceContainer{v: nilSlice}, + sliceContainer{v: emptySlice}, + sliceContainer{v: fullSlice}, + slicePtrContainer{v: nil}, + slicePtrContainer{v: &nilSlice}, + slicePtrContainer{v: &emptySlice}, + slicePtrContainer{v: &fullSlice}, + }) +} diff --git a/pkg/state/tests/bench.go b/pkg/state/tests/bench.go new file mode 100644 index 000000000..40869cdfb --- /dev/null +++ b/pkg/state/tests/bench.go @@ -0,0 +1,24 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type benchStruct struct { + B *benchStruct // Must be exported for gob. +} + +func (b *benchStruct) afterLoad() { + // Do nothing, just force scheduling. +} diff --git a/pkg/state/tests/bench_test.go b/pkg/state/tests/bench_test.go new file mode 100644 index 000000000..7e102c907 --- /dev/null +++ b/pkg/state/tests/bench_test.go @@ -0,0 +1,153 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "bytes" + "context" + "encoding/gob" + "fmt" + "testing" + + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/wire" +) + +// buildPtrObject builds a benchmark object. +func buildPtrObject(n int) interface{} { + b := new(benchStruct) + for i := 0; i < n; i++ { + b = &benchStruct{B: b} + } + return b +} + +// buildMapObject builds a benchmark object. +func buildMapObject(n int) interface{} { + b := new(benchStruct) + m := make(map[int]*benchStruct) + for i := 0; i < n; i++ { + m[i] = b + } + return &m +} + +// buildSliceObject builds a benchmark object. +func buildSliceObject(n int) interface{} { + b := new(benchStruct) + s := make([]*benchStruct, 0, n) + for i := 0; i < n; i++ { + s = append(s, b) + } + return &s +} + +var allObjects = map[string]struct { + New func(int) interface{} +}{ + "ptr": { + New: buildPtrObject, + }, + "map": { + New: buildMapObject, + }, + "slice": { + New: buildSliceObject, + }, +} + +func buildObjects(n int, fn func(int) interface{}) (iters int, v interface{}) { + // maxSize is the maximum size of an individual object below. For an N + // larger than this, we start to return multiple objects. + const maxSize = 1024 + if n <= maxSize { + return 1, fn(n) + } + iters = (n + maxSize - 1) / maxSize + return iters, fn(maxSize) +} + +// gobSave is a version of save using gob (no stats available). +func gobSave(_ context.Context, w wire.Writer, v interface{}) (_ state.Stats, err error) { + enc := gob.NewEncoder(w) + err = enc.Encode(v) + return +} + +// gobLoad is a version of load using gob (no stats available). +func gobLoad(_ context.Context, r wire.Reader, v interface{}) (_ state.Stats, err error) { + dec := gob.NewDecoder(r) + err = dec.Decode(v) + return +} + +var allAlgos = map[string]struct { + Save func(context.Context, wire.Writer, interface{}) (state.Stats, error) + Load func(context.Context, wire.Reader, interface{}) (state.Stats, error) + MaxPtr int +}{ + "state": { + Save: state.Save, + Load: state.Load, + }, + "gob": { + Save: gobSave, + Load: gobLoad, + }, +} + +func BenchmarkEncoding(b *testing.B) { + for objName, objInfo := range allObjects { + for algoName, algoInfo := range allAlgos { + b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) { + b.StopTimer() + n, v := buildObjects(b.N, objInfo.New) + b.ReportAllocs() + b.StartTimer() + for i := 0; i < n; i++ { + if _, err := algoInfo.Save(context.Background(), discard{}, v); err != nil { + b.Errorf("save failed: %v", err) + } + } + b.StopTimer() + }) + } + } +} + +func BenchmarkDecoding(b *testing.B) { + for objName, objInfo := range allObjects { + for algoName, algoInfo := range allAlgos { + b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) { + b.StopTimer() + n, v := buildObjects(b.N, objInfo.New) + buf := new(bytes.Buffer) + if _, err := algoInfo.Save(context.Background(), buf, v); err != nil { + b.Errorf("save failed: %v", err) + } + b.ReportAllocs() + b.StartTimer() + var r bytes.Reader + for i := 0; i < n; i++ { + r.Reset(buf.Bytes()) + if _, err := algoInfo.Load(context.Background(), &r, v); err != nil { + b.Errorf("load failed: %v", err) + } + } + b.StopTimer() + }) + } + } +} diff --git a/pkg/state/tests/bool_test.go b/pkg/state/tests/bool_test.go new file mode 100644 index 000000000..e17cfacf9 --- /dev/null +++ b/pkg/state/tests/bool_test.go @@ -0,0 +1,31 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" +) + +var allBools = []bool{ + true, + false, +} + +func TestBool(t *testing.T) { + runTestCases(t, false, "plain", flatten(allBools)) + runTestCases(t, false, "pointers", pointersTo(flatten(allBools))) + runTestCases(t, false, "interfaces", interfacesTo(flatten(allBools))) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allBools)))) +} diff --git a/pkg/state/tests/float_test.go b/pkg/state/tests/float_test.go new file mode 100644 index 000000000..3e89edd9c --- /dev/null +++ b/pkg/state/tests/float_test.go @@ -0,0 +1,118 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "math" + "testing" +) + +var safeFloat32s = []float32{ + float32(0.0), + float32(1.0), + float32(-1.0), + float32(math.Inf(1)), + float32(math.Inf(-1)), +} + +var allFloat32s = append(safeFloat32s, float32(math.NaN())) + +var safeFloat64s = []float64{ + float64(0.0), + float64(1.0), + float64(-1.0), + math.Inf(1), + math.Inf(-1), +} + +var allFloat64s = append(safeFloat64s, math.NaN()) + +func TestFloat(t *testing.T) { + runTestCases(t, false, "plain", flatten( + allFloat32s, + allFloat64s, + )) + // See checkEqual for why NaNs are missing. + runTestCases(t, false, "pointers", pointersTo(flatten( + safeFloat32s, + safeFloat64s, + ))) + runTestCases(t, false, "interfaces", interfacesTo(flatten( + safeFloat32s, + safeFloat64s, + ))) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten( + safeFloat32s, + safeFloat64s, + )))) +} + +const onlyDouble float64 = 1.0000000000000002 + +func TestFloatTruncation(t *testing.T) { + runTestCases(t, true, "pass", []interface{}{ + truncatingFloat32{save: onlyDouble}, + }) + runTestCases(t, false, "fail", []interface{}{ + truncatingFloat32{save: 1.0}, + }) +} + +var safeComplex64s = combine(safeFloat32s, safeFloat32s, func(i, j interface{}) interface{} { + return complex(i.(float32), j.(float32)) +}) + +var allComplex64s = combine(allFloat32s, allFloat32s, func(i, j interface{}) interface{} { + return complex(i.(float32), j.(float32)) +}) + +var safeComplex128s = combine(safeFloat64s, safeFloat64s, func(i, j interface{}) interface{} { + return complex(i.(float64), j.(float64)) +}) + +var allComplex128s = combine(allFloat64s, allFloat64s, func(i, j interface{}) interface{} { + return complex(i.(float64), j.(float64)) +}) + +func TestComplex(t *testing.T) { + runTestCases(t, false, "plain", flatten( + allComplex64s, + allComplex128s, + )) + // See TestFloat; same issue. + runTestCases(t, false, "pointers", pointersTo(flatten( + safeComplex64s, + safeComplex128s, + ))) + runTestCases(t, false, "interfacse", interfacesTo(flatten( + safeComplex64s, + safeComplex128s, + ))) + runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(flatten( + safeComplex64s, + safeComplex128s, + )))) +} + +func TestComplexTruncation(t *testing.T) { + runTestCases(t, true, "pass", []interface{}{ + truncatingComplex64{save: complex(onlyDouble, onlyDouble)}, + truncatingComplex64{save: complex(1.0, onlyDouble)}, + truncatingComplex64{save: complex(onlyDouble, 1.0)}, + }) + runTestCases(t, false, "fail", []interface{}{ + truncatingComplex64{save: complex(1.0, 1.0)}, + }) +} diff --git a/pkg/state/tests/integer.go b/pkg/state/tests/integer.go new file mode 100644 index 000000000..ca403eed1 --- /dev/null +++ b/pkg/state/tests/integer.go @@ -0,0 +1,163 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +// +stateify type +type truncatingUint8 struct { + save uint64 + load uint8 `state:"nosave"` +} + +func (t *truncatingUint8) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingUint8) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = uint64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingUint8)(nil) + +// +stateify type +type truncatingUint16 struct { + save uint64 + load uint16 `state:"nosave"` +} + +func (t *truncatingUint16) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingUint16) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = uint64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingUint16)(nil) + +// +stateify type +type truncatingUint32 struct { + save uint64 + load uint32 `state:"nosave"` +} + +func (t *truncatingUint32) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingUint32) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = uint64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingUint32)(nil) + +// +stateify type +type truncatingInt8 struct { + save int64 + load int8 `state:"nosave"` +} + +func (t *truncatingInt8) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingInt8) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = int64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingInt8)(nil) + +// +stateify type +type truncatingInt16 struct { + save int64 + load int16 `state:"nosave"` +} + +func (t *truncatingInt16) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingInt16) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = int64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingInt16)(nil) + +// +stateify type +type truncatingInt32 struct { + save int64 + load int32 `state:"nosave"` +} + +func (t *truncatingInt32) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingInt32) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = int64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingInt32)(nil) + +// +stateify type +type truncatingFloat32 struct { + save float64 + load float32 `state:"nosave"` +} + +func (t *truncatingFloat32) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingFloat32) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = float64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingFloat32)(nil) + +// +stateify type +type truncatingComplex64 struct { + save complex128 + load complex64 `state:"nosave"` +} + +func (t *truncatingComplex64) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingComplex64) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = complex128(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingComplex64)(nil) diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go new file mode 100644 index 000000000..d3931c952 --- /dev/null +++ b/pkg/state/tests/integer_test.go @@ -0,0 +1,94 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "math" + "testing" +) + +var ( + allIntTs = []int{-1, 0, 1} + allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8} + allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16} + allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32} + allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64} + allUintTs = []uint{0, 1} + allUintptrs = []uintptr{0, 1, ^uintptr(0)} + allUint8s = []uint8{0, 1, math.MaxUint8} + allUint16s = []uint16{0, 1, math.MaxUint16} + allUint32s = []uint32{0, 1, math.MaxUint32} + allUint64s = []uint64{0, 1, math.MaxUint64} +) + +var allInts = flatten( + allIntTs, + allInt8s, + allInt16s, + allInt32s, + allInt64s, +) + +var allUints = flatten( + allUintTs, + allUintptrs, + allUint8s, + allUint16s, + allUint32s, + allUint64s, +) + +func TestInt(t *testing.T) { + runTestCases(t, false, "plain", allInts) + runTestCases(t, false, "pointers", pointersTo(allInts)) + runTestCases(t, false, "interfaces", interfacesTo(allInts)) + runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allInts))) +} + +func TestIntTruncation(t *testing.T) { + runTestCases(t, true, "pass", []interface{}{ + truncatingInt8{save: math.MinInt8 - 1}, + truncatingInt16{save: math.MinInt16 - 1}, + truncatingInt32{save: math.MinInt32 - 1}, + truncatingInt8{save: math.MaxInt8 + 1}, + truncatingInt16{save: math.MaxInt16 + 1}, + truncatingInt32{save: math.MaxInt32 + 1}, + }) + runTestCases(t, false, "fail", []interface{}{ + truncatingInt8{save: 1}, + truncatingInt16{save: 1}, + truncatingInt32{save: 1}, + }) +} + +func TestUint(t *testing.T) { + runTestCases(t, false, "plain", allUints) + runTestCases(t, false, "pointers", pointersTo(allUints)) + runTestCases(t, false, "interfaces", interfacesTo(allUints)) + runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allUints))) +} + +func TestUintTruncation(t *testing.T) { + runTestCases(t, true, "pass", []interface{}{ + truncatingUint8{save: math.MaxUint8 + 1}, + truncatingUint16{save: math.MaxUint16 + 1}, + truncatingUint32{save: math.MaxUint32 + 1}, + }) + runTestCases(t, false, "fail", []interface{}{ + truncatingUint8{save: 1}, + truncatingUint16{save: 1}, + truncatingUint32{save: 1}, + }) +} diff --git a/pkg/state/tests/load.go b/pkg/state/tests/load.go new file mode 100644 index 000000000..a8350c0f3 --- /dev/null +++ b/pkg/state/tests/load.go @@ -0,0 +1,61 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type genericContainer struct { + v interface{} +} + +// +stateify savable +type afterLoadStruct struct { + v int `state:"nosave"` +} + +func (a *afterLoadStruct) afterLoad() { + a.v++ +} + +// +stateify savable +type valueLoadStruct struct { + v int `state:".(int64)"` +} + +func (v *valueLoadStruct) saveV() int64 { + return int64(v.v) // Save as int64. +} + +func (v *valueLoadStruct) loadV(value int64) { + v.v = int(value) // Load as int. +} + +// +stateify savable +type cycleStruct struct { + c *cycleStruct +} + +// +stateify savable +type badCycleStruct struct { + b *badCycleStruct `state:"wait"` +} + +func (b *badCycleStruct) afterLoad() { + if b.b != b { + // This is not executable, since AfterLoad requires that the + // object and all dependencies are complete. This should cause + // a deadlock error during load. + panic("badCycleStruct.afterLoad called") + } +} diff --git a/pkg/state/tests/load_test.go b/pkg/state/tests/load_test.go new file mode 100644 index 000000000..1e9794296 --- /dev/null +++ b/pkg/state/tests/load_test.go @@ -0,0 +1,70 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" +) + +func TestLoadHooks(t *testing.T) { + runTestCases(t, false, "load-hooks", []interface{}{ + &afterLoadStruct{v: 1}, + &valueLoadStruct{v: 1}, + &genericContainer{v: &afterLoadStruct{v: 1}}, + &genericContainer{v: &valueLoadStruct{v: 1}}, + &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}}, + &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}}, + &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}}, + &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}}, + }) +} + +func TestCycles(t *testing.T) { + // cs is a single object cycle. + cs := cycleStruct{nil} + cs.c = &cs + + // cs1 and cs2 are in a two object cycle. + cs1 := cycleStruct{nil} + cs2 := cycleStruct{nil} + cs1.c = &cs2 + cs2.c = &cs1 + + runTestCases(t, false, "cycles", []interface{}{ + cs, + cs1, + }) +} + +func TestDeadlock(t *testing.T) { + // bs is a single object cycle. This does not cause deadlock because an + // object cannot wait for itself. + bs := badCycleStruct{nil} + bs.b = &bs + + runTestCases(t, false, "self", []interface{}{ + &bs, + }) + + // bs2 and bs2 are in a deadlocking cycle. + bs1 := badCycleStruct{nil} + bs2 := badCycleStruct{nil} + bs1.b = &bs2 + bs2.b = &bs1 + + runTestCases(t, true, "deadlock", []interface{}{ + &bs1, + }) +} diff --git a/pkg/state/tests/map.go b/pkg/state/tests/map.go new file mode 100644 index 000000000..db4e548f1 --- /dev/null +++ b/pkg/state/tests/map.go @@ -0,0 +1,28 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type mapContainer struct { + v map[int]interface{} +} + +// +stateify savable +type mapPtrContainer struct { + v *map[int]interface{} +} + +// +stateify savable +type registeredMapStruct struct{} diff --git a/pkg/state/tests/map_test.go b/pkg/state/tests/map_test.go new file mode 100644 index 000000000..92bf0fc01 --- /dev/null +++ b/pkg/state/tests/map_test.go @@ -0,0 +1,90 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "reflect" + "testing" +) + +var allMapPrimitives = []interface{}{ + bool(true), + int(1), + int8(1), + int16(1), + int32(1), + int64(1), + uint(1), + uintptr(1), + uint8(1), + uint16(1), + uint32(1), + uint64(1), + string(""), + registeredMapStruct{}, +} + +var allMapKeys = flatten(allMapPrimitives, pointersTo(allMapPrimitives)) + +var allMapValues = flatten(allMapPrimitives, pointersTo(allMapPrimitives), interfacesTo(allMapPrimitives)) + +var emptyMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} { + m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2))) + return m.Interface() +}) + +var fullMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} { + m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2))) + m.SetMapIndex(reflect.Zero(reflect.TypeOf(v1)), reflect.Zero(reflect.TypeOf(v2))) + return m.Interface() +}) + +func TestMapAliasing(t *testing.T) { + v := make(map[int]int) + ptrToV := &v + aliases := []map[int]int{v, v} + runTestCases(t, false, "", []interface{}{ptrToV, aliases}) +} + +func TestMapsEmpty(t *testing.T) { + runTestCases(t, false, "plain", emptyMaps) + runTestCases(t, false, "pointers", pointersTo(emptyMaps)) + runTestCases(t, false, "interfaces", interfacesTo(emptyMaps)) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(emptyMaps))) +} + +func TestMapsFull(t *testing.T) { + runTestCases(t, false, "plain", fullMaps) + runTestCases(t, false, "pointers", pointersTo(fullMaps)) + runTestCases(t, false, "interfaces", interfacesTo(fullMaps)) + runTestCases(t, false, "interfacesToPointer", interfacesTo(pointersTo(fullMaps))) +} + +func TestMapContainers(t *testing.T) { + var ( + nilMap map[int]interface{} + emptyMap = make(map[int]interface{}) + fullMap = map[int]interface{}{0: nil} + ) + runTestCases(t, false, "", []interface{}{ + mapContainer{v: nilMap}, + mapContainer{v: emptyMap}, + mapContainer{v: fullMap}, + mapPtrContainer{v: nil}, + mapPtrContainer{v: &nilMap}, + mapPtrContainer{v: &emptyMap}, + mapPtrContainer{v: &fullMap}, + }) +} diff --git a/pkg/state/tests/register.go b/pkg/state/tests/register.go new file mode 100644 index 000000000..074d86315 --- /dev/null +++ b/pkg/state/tests/register.go @@ -0,0 +1,21 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type alreadyRegisteredStruct struct{} + +// +stateify savable +type alreadyRegisteredOther int diff --git a/pkg/state/tests/register_test.go b/pkg/state/tests/register_test.go new file mode 100644 index 000000000..c829753cc --- /dev/null +++ b/pkg/state/tests/register_test.go @@ -0,0 +1,167 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/state" +) + +// faker calls itself whatever is in the name field. +type faker struct { + Name string + Fields []string +} + +func (f *faker) StateTypeName() string { + return f.Name +} + +func (f *faker) StateFields() []string { + return f.Fields +} + +// fakerWithSaverLoader has all it needs. +type fakerWithSaverLoader struct { + faker +} + +func (f *fakerWithSaverLoader) StateSave(m state.Sink) {} + +func (f *fakerWithSaverLoader) StateLoad(m state.Source) {} + +// fakerOther calls itself .. uh, itself? +type fakerOther string + +func (f *fakerOther) StateTypeName() string { + return string(*f) +} + +func (f *fakerOther) StateFields() []string { + return nil +} + +func newFakerOther(name string) *fakerOther { + f := fakerOther(name) + return &f +} + +// fakerOtherBadFields returns non-nil fields. +type fakerOtherBadFields string + +func (f *fakerOtherBadFields) StateTypeName() string { + return string(*f) +} + +func (f *fakerOtherBadFields) StateFields() []string { + return []string{string(*f)} +} + +func newFakerOtherBadFields(name string) *fakerOtherBadFields { + f := fakerOtherBadFields(name) + return &f +} + +// fakerOtherSaverLoader implements SaverLoader methods. +type fakerOtherSaverLoader string + +func (f *fakerOtherSaverLoader) StateTypeName() string { + return string(*f) +} + +func (f *fakerOtherSaverLoader) StateFields() []string { + return nil +} + +func (f *fakerOtherSaverLoader) StateSave(m state.Sink) {} + +func (f *fakerOtherSaverLoader) StateLoad(m state.Source) {} + +func newFakerOtherSaverLoader(name string) *fakerOtherSaverLoader { + f := fakerOtherSaverLoader(name) + return &f +} + +func TestRegisterPrimitives(t *testing.T) { + for _, typeName := range []string{ + "int", + "int8", + "int16", + "int32", + "int64", + "uint", + "uintptr", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + "string", + } { + t.Run("struct/"+typeName, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Registering type %q did not panic", typeName) + } + }() + state.Register(&faker{ + Name: typeName, + }) + }) + t.Run("other/"+typeName, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Registering type %q did not panic", typeName) + } + }() + state.Register(newFakerOther(typeName)) + }) + } +} + +func TestRegisterBad(t *testing.T) { + const ( + goodName = "foo" + firstField = "a" + secondField = "b" + ) + for name, object := range map[string]state.Type{ + "non-struct-with-fields": newFakerOtherBadFields(goodName), + "non-struct-with-saverloader": newFakerOtherSaverLoader(goodName), + "struct-without-saverloader": &faker{Name: goodName}, + "non-struct-duplicate-with-struct": newFakerOther((new(alreadyRegisteredStruct)).StateTypeName()), + "non-struct-duplicate-with-non-struct": newFakerOther((new(alreadyRegisteredOther)).StateTypeName()), + "struct-duplicate-with-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredStruct)).StateTypeName()}}, + "struct-duplicate-with-non-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredOther)).StateTypeName()}}, + "struct-with-empty-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{""}}}, + "struct-with-empty-field-and-non-empty": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, ""}}}, + "struct-with-duplicate-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, firstField}}}, + "struct-with-duplicate-field-and-non-dup": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, secondField, firstField}}}, + } { + t.Run(name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Registering object %#v did not panic", object) + } + }() + state.Register(object) + }) + + } +} diff --git a/pkg/state/tests/string_test.go b/pkg/state/tests/string_test.go new file mode 100644 index 000000000..44f5a562c --- /dev/null +++ b/pkg/state/tests/string_test.go @@ -0,0 +1,34 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" +) + +const nonEmptyString = "hello world" + +var allStrings = []string{ + "", + nonEmptyString, + "\\0", +} + +func TestString(t *testing.T) { + runTestCases(t, false, "plain", flatten(allStrings)) + runTestCases(t, false, "pointers", pointersTo(flatten(allStrings))) + runTestCases(t, false, "interfaces", interfacesTo(flatten(allStrings))) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allStrings)))) +} diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go new file mode 100644 index 000000000..bd2c2b399 --- /dev/null +++ b/pkg/state/tests/struct.go @@ -0,0 +1,65 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +type unregisteredEmptyStruct struct{} + +// typeOnlyEmptyStruct just implements the state.Type interface. +type typeOnlyEmptyStruct struct{} + +func (*typeOnlyEmptyStruct) StateTypeName() string { return "registeredEmptyStruct" } + +func (*typeOnlyEmptyStruct) StateFields() []string { return nil } + +// +stateify savable +type savableEmptyStruct struct{} + +// +stateify savable +type emptyStructPointer struct { + nothing *struct{} +} + +// +stateify savable +type outerSame struct { + inner inner +} + +// +stateify savable +type outerFieldFirst struct { + inner inner + v int64 +} + +// +stateify savable +type outerFieldSecond struct { + v int64 + inner inner +} + +// +stateify savable +type outerArray struct { + inner [2]inner +} + +// +stateify savable +type inner struct { + v int64 +} + +// +stateify savable +type system struct { + v1 interface{} + v2 interface{} +} diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go new file mode 100644 index 000000000..de9d17aa7 --- /dev/null +++ b/pkg/state/tests/struct_test.go @@ -0,0 +1,89 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/state" +) + +func TestEmptyStruct(t *testing.T) { + runTestCases(t, false, "plain", []interface{}{ + unregisteredEmptyStruct{}, + typeOnlyEmptyStruct{}, + savableEmptyStruct{}, + }) + runTestCases(t, false, "pointers", pointersTo([]interface{}{ + unregisteredEmptyStruct{}, + typeOnlyEmptyStruct{}, + savableEmptyStruct{}, + })) + runTestCases(t, false, "interfaces-pass", interfacesTo([]interface{}{ + // Only registered types can be dispatched via interfaces. All + // other types should fail, even if it is the empty struct. + savableEmptyStruct{}, + })) + runTestCases(t, true, "interfaces-fail", interfacesTo([]interface{}{ + unregisteredEmptyStruct{}, + typeOnlyEmptyStruct{}, + })) + runTestCases(t, false, "interfacesToPointers-pass", interfacesTo(pointersTo([]interface{}{ + savableEmptyStruct{}, + }))) + runTestCases(t, true, "interfacesToPointers-fail", interfacesTo(pointersTo([]interface{}{ + unregisteredEmptyStruct{}, + typeOnlyEmptyStruct{}, + }))) + + // Ensuring empty struct aliasing works. + es := emptyStructPointer{new(struct{})} + runTestCases(t, false, "empty-struct-pointers", []interface{}{ + emptyStructPointer{}, + es, + []emptyStructPointer{es, es}, // Same pointer. + }) +} + +func TestRegisterTypeOnlyStruct(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Register did not panic") + } + }() + state.Register((*typeOnlyEmptyStruct)(nil)) +} + +func TestEmbeddedPointers(t *testing.T) { + var ( + ofs outerSame + of1 outerFieldFirst + of2 outerFieldSecond + oa outerArray + ) + + runTestCases(t, false, "embedded-pointers", []interface{}{ + system{&ofs, &ofs.inner}, + system{&ofs.inner, &ofs}, + system{&of1, &of1.inner}, + system{&of1.inner, &of1}, + system{&of2, &of2.inner}, + system{&of2.inner, &of2}, + system{&oa, &oa.inner[0]}, + system{&oa, &oa.inner[1]}, + system{&oa.inner[0], &oa}, + system{&oa.inner[1], &oa}, + }) +} diff --git a/pkg/state/tests/tests.go b/pkg/state/tests/tests.go new file mode 100644 index 000000000..435a0e9db --- /dev/null +++ b/pkg/state/tests/tests.go @@ -0,0 +1,215 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tests tests the state packages. +package tests + +import ( + "bytes" + "context" + "fmt" + "math" + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/pretty" +) + +// discard is an implementation of wire.Writer. +type discard struct{} + +// Write implements wire.Writer.Write. +func (discard) Write(p []byte) (int, error) { return len(p), nil } + +// WriteByte implements wire.Writer.WriteByte. +func (discard) WriteByte(byte) error { return nil } + +// checkEqual checks if two objects are equal. +// +// N.B. This only handles one level of dereferences for NaN. Otherwise we +// would need to fork the entire implementation of reflect.DeepEqual. +func checkEqual(root, loadedValue interface{}) bool { + if reflect.DeepEqual(root, loadedValue) { + return true + } + + // NaN is not equal to itself. We handle the case of raw floating point + // primitives here, but don't handle this case nested. + rf32, ok1 := root.(float32) + lf32, ok2 := loadedValue.(float32) + if ok1 && ok2 && math.IsNaN(float64(rf32)) && math.IsNaN(float64(lf32)) { + return true + } + rf64, ok1 := root.(float64) + lf64, ok2 := loadedValue.(float64) + if ok1 && ok2 && math.IsNaN(rf64) && math.IsNaN(lf64) { + return true + } + + // Same real for complex numbers. + rc64, ok1 := root.(complex64) + lc64, ok2 := root.(complex64) + if ok1 && ok2 { + return checkEqual(real(rc64), real(lc64)) && checkEqual(imag(rc64), imag(lc64)) + } + rc128, ok1 := root.(complex128) + lc128, ok2 := root.(complex128) + if ok1 && ok2 { + return checkEqual(real(rc128), real(lc128)) && checkEqual(imag(rc128), imag(lc128)) + } + + return false +} + +// runTestCases runs a test for each object in objects. +func runTestCases(t *testing.T, shouldFail bool, prefix string, objects []interface{}) { + t.Helper() + for i, root := range objects { + t.Run(fmt.Sprintf("%s%d", prefix, i), func(t *testing.T) { + t.Logf("Original object:\n%#v", root) + + // Save the passed object. + saveBuffer := &bytes.Buffer{} + saveObjectPtr := reflect.New(reflect.TypeOf(root)) + saveObjectPtr.Elem().Set(reflect.ValueOf(root)) + saveStats, err := state.Save(context.Background(), saveBuffer, saveObjectPtr.Interface()) + if err != nil { + if shouldFail { + return + } + t.Fatalf("Save failed unexpectedly: %v", err) + } + + // Dump the serialized proto to aid with debugging. + var ppBuf bytes.Buffer + t.Logf("Raw state:\n%v", saveBuffer.Bytes()) + if err := pretty.PrintText(&ppBuf, bytes.NewReader(saveBuffer.Bytes())); err != nil { + // We don't count this as a test failure if we + // have shouldFail set, but we will count as a + // failure if we were not expecting to fail. + if !shouldFail { + t.Errorf("PrettyPrint(html=false) failed unexpected: %v", err) + } + } + if err := pretty.PrintHTML(discard{}, bytes.NewReader(saveBuffer.Bytes())); err != nil { + // See above. + if !shouldFail { + t.Errorf("PrettyPrint(html=true) failed unexpected: %v", err) + } + } + t.Logf("Encoded state:\n%s", ppBuf.String()) + t.Logf("Save stats:\n%s", saveStats.String()) + + // Load a new copy of the object. + loadObjectPtr := reflect.New(reflect.TypeOf(root)) + loadStats, err := state.Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface()) + if err != nil { + if shouldFail { + return + } + t.Fatalf("Load failed unexpectedly: %v", err) + } + + // Compare the values. + loadedValue := loadObjectPtr.Elem().Interface() + if !checkEqual(root, loadedValue) { + if shouldFail { + return + } + t.Fatalf("Objects differ:\n\toriginal: %#v\n\tloaded: %#v\n", root, loadedValue) + } + + // Everything went okay. Is that good? + if shouldFail { + t.Fatalf("This test was expected to fail, but didn't.") + } + t.Logf("Load stats:\n%s", loadStats.String()) + + // Truncate half the bytes in the byte stream, + // and ensure that we can't restore. Then + // truncate only the final byte and ensure that + // we can't restore. + l := saveBuffer.Len() + halfReader := bytes.NewReader(saveBuffer.Bytes()[:l/2]) + if _, err := state.Load(context.Background(), halfReader, loadObjectPtr.Interface()); err == nil { + t.Errorf("Load with half bytes succeeded unexpectedly.") + } + missingByteReader := bytes.NewReader(saveBuffer.Bytes()[:l-1]) + if _, err := state.Load(context.Background(), missingByteReader, loadObjectPtr.Interface()); err == nil { + t.Errorf("Load with missing byte succeeded unexpectedly.") + } + }) + } +} + +// convert converts the slice to an []interface{}. +func convert(v interface{}) (r []interface{}) { + s := reflect.ValueOf(v) // Must be slice. + for i := 0; i < s.Len(); i++ { + r = append(r, s.Index(i).Interface()) + } + return r +} + +// flatten flattens multiple slices. +func flatten(vs ...interface{}) (r []interface{}) { + for _, v := range vs { + r = append(r, convert(v)...) + } + return r +} + +// filter maps from one slice to another. +func filter(vs interface{}, fn func(interface{}) (interface{}, bool)) (r []interface{}) { + s := reflect.ValueOf(vs) + for i := 0; i < s.Len(); i++ { + v, ok := fn(s.Index(i).Interface()) + if ok { + r = append(r, v) + } + } + return r +} + +// combine combines objects in two slices as specified. +func combine(v1, v2 interface{}, fn func(_, _ interface{}) interface{}) (r []interface{}) { + s1 := reflect.ValueOf(v1) + s2 := reflect.ValueOf(v2) + for i := 0; i < s1.Len(); i++ { + for j := 0; j < s2.Len(); j++ { + // Combine using the given function. + r = append(r, fn(s1.Index(i).Interface(), s2.Index(j).Interface())) + } + } + return r +} + +// pointersTo is a filter function that returns pointers. +func pointersTo(vs interface{}) []interface{} { + return filter(vs, func(o interface{}) (interface{}, bool) { + v := reflect.New(reflect.TypeOf(o)) + v.Elem().Set(reflect.ValueOf(o)) + return v.Interface(), true + }) +} + +// interfacesTo is a filter function that returns interface objects. +func interfacesTo(vs interface{}) []interface{} { + return filter(vs, func(o interface{}) (interface{}, bool) { + var v [1]interface{} + v[0] = o + return v, true + }) +} diff --git a/pkg/state/types.go b/pkg/state/types.go new file mode 100644 index 000000000..215ef80f8 --- /dev/null +++ b/pkg/state/types.go @@ -0,0 +1,361 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "reflect" + "sort" + + "gvisor.dev/gvisor/pkg/state/wire" +) + +// assertValidType asserts that the type is valid. +func assertValidType(name string, fields []string) { + if name == "" { + Failf("type has empty name") + } + fieldsCopy := make([]string, len(fields)) + for i := 0; i < len(fields); i++ { + if fields[i] == "" { + Failf("field has empty name for type %q", name) + } + fieldsCopy[i] = fields[i] + } + sort.Slice(fieldsCopy, func(i, j int) bool { + return fieldsCopy[i] < fieldsCopy[j] + }) + for i := range fieldsCopy { + if i > 0 && fieldsCopy[i-1] == fieldsCopy[i] { + Failf("duplicate field %q for type %s", fieldsCopy[i], name) + } + } +} + +// typeEntry is an entry in the typeDatabase. +type typeEntry struct { + ID typeID + wire.Type +} + +// reconciledTypeEntry is a reconciled entry in the typeDatabase. +type reconciledTypeEntry struct { + wire.Type + LocalType reflect.Type + FieldOrder []int +} + +// typeEncodeDatabase is an internal TypeInfo database for encoding. +type typeEncodeDatabase struct { + // byType maps by type to the typeEntry. + byType map[reflect.Type]*typeEntry + + // lastID is the last used ID. + lastID typeID +} + +// makeTypeEncodeDatabase makes a typeDatabase. +func makeTypeEncodeDatabase() typeEncodeDatabase { + return typeEncodeDatabase{ + byType: make(map[reflect.Type]*typeEntry), + } +} + +// typeDecodeDatabase is an internal TypeInfo database for decoding. +type typeDecodeDatabase struct { + // byID maps by ID to type. + byID []*reconciledTypeEntry + + // pending are entries that are pending validation by Lookup. These + // will be reconciled with actual objects. Note that these will also be + // used to lookup types by name, since they may not be reconciled and + // there's little value to deleting from this map. + pending []*wire.Type +} + +// makeTypeDecodeDatabase makes a typeDatabase. +func makeTypeDecodeDatabase() typeDecodeDatabase { + return typeDecodeDatabase{} +} + +// lookupNameFields extracts the name and fields from an object. +func lookupNameFields(typ reflect.Type) (string, []string, bool) { + v := reflect.Zero(reflect.PtrTo(typ)).Interface() + t, ok := v.(Type) + if !ok { + // Is this a primitive? + if typ.Kind() == reflect.Interface { + return interfaceType, nil, true + } + name := typ.Name() + if _, ok := primitiveTypeDatabase[name]; !ok { + // This is not a known type, and not a primitive. The + // encoder may proceed for anonymous empty structs, or + // it may deference the type pointer and try again. + return "", nil, false + } + return name, nil, true + } + // Extract the name from the object. + name := t.StateTypeName() + fields := t.StateFields() + assertValidType(name, fields) + return name, fields, true +} + +// Lookup looks up or registers the given object. +// +// The bool indicates whether this is an existing entry: false means the entry +// did not exist, and true means the entry did exist. If this bool is false and +// the returned typeEntry are nil, then the obj did not implement the Type +// interface. +func (tdb *typeEncodeDatabase) Lookup(typ reflect.Type) (*typeEntry, bool) { + te, ok := tdb.byType[typ] + if !ok { + // Lookup the type information. + name, fields, ok := lookupNameFields(typ) + if !ok { + // Empty structs may still be encoded, so let the + // caller decide what to do from here. + return nil, false + } + + // Register the new type. + tdb.lastID++ + te = &typeEntry{ + ID: tdb.lastID, + Type: wire.Type{ + Name: name, + Fields: fields, + }, + } + + // All done. + tdb.byType[typ] = te + return te, false + } + return te, true +} + +// Register adds a typeID entry. +func (tbd *typeDecodeDatabase) Register(typ *wire.Type) { + assertValidType(typ.Name, typ.Fields) + tbd.pending = append(tbd.pending, typ) +} + +// LookupName looks up the type name by ID. +func (tbd *typeDecodeDatabase) LookupName(id typeID) string { + if len(tbd.pending) < int(id) { + // This is likely an encoder error? + Failf("type ID %d not available", id) + } + return tbd.pending[id-1].Name +} + +// LookupType looks up the type by ID. +func (tbd *typeDecodeDatabase) LookupType(id typeID) reflect.Type { + name := tbd.LookupName(id) + typ, ok := globalTypeDatabase[name] + if !ok { + // If not available, see if it's primitive. + typ, ok = primitiveTypeDatabase[name] + if !ok && name == interfaceType { + // Matches the built-in interface type. + var i interface{} + return reflect.TypeOf(&i).Elem() + } + if !ok { + // The type is perhaps not registered? + Failf("type name %q is not available", name) + } + return typ // Primitive type. + } + return typ // Registered type. +} + +// singleFieldOrder defines the field order for a single field. +var singleFieldOrder = []int{0} + +// Lookup looks up or registers the given object. +// +// First, the typeID is searched to see if this has already been appropriately +// reconciled. If no, then a reconcilation will take place that may result in a +// field ordering. If a nil reconciledTypeEntry is returned from this method, +// then the object does not support the Type interface. +// +// This method never returns nil. +func (tbd *typeDecodeDatabase) Lookup(id typeID, typ reflect.Type) *reconciledTypeEntry { + if len(tbd.byID) > int(id) && tbd.byID[id-1] != nil { + // Already reconciled. + return tbd.byID[id-1] + } + // The ID has not been reconciled yet. That's fine. We need to make + // sure it aligns with the current provided object. + if len(tbd.pending) < int(id) { + // This id was never registered. Probably an encoder error? + Failf("typeDatabase does not contain id %d", id) + } + // Extract the pending info. + pending := tbd.pending[id-1] + // Grow the byID list. + if len(tbd.byID) < int(id) { + tbd.byID = append(tbd.byID, make([]*reconciledTypeEntry, int(id)-len(tbd.byID))...) + } + // Reconcile the type. + name, fields, ok := lookupNameFields(typ) + if !ok { + // Empty structs are decoded only when the type is nil. Since + // this isn't the case, we fail here. + Failf("unsupported type %q during decode; can't reconcile", pending.Name) + } + if name != pending.Name { + // Are these the same type? Print a helpful message as this may + // actually happen in practice if types change. + Failf("typeDatabase contains conflicting definitions for id %d: %s->%v (current) and %s->%v (existing)", + id, name, fields, pending.Name, pending.Fields) + } + rte := &reconciledTypeEntry{ + Type: wire.Type{ + Name: name, + Fields: fields, + }, + LocalType: typ, + } + // If there are zero or one fields, then we skip allocating the field + // slice. There is special handling for decoding in this case. If the + // field name does not match, it will be caught in the general purpose + // code below. + if len(fields) != len(pending.Fields) { + Failf("type %q contains different fields: %v (decode) and %v (encode)", + name, fields, pending.Fields) + } + if len(fields) == 0 { + tbd.byID[id-1] = rte // Save. + return rte + } + if len(fields) == 1 && fields[0] == pending.Fields[0] { + tbd.byID[id-1] = rte // Save. + rte.FieldOrder = singleFieldOrder + return rte + } + // For each field in the current object's information, match it to a + // field in the destination object. We know from the assertion above + // and the insertion on insertion to pending that neither field + // contains any duplicates. + fieldOrder := make([]int, len(fields)) + for i, name := range fields { + fieldOrder[i] = -1 // Sentinel. + // Is it an exact match? + if pending.Fields[i] == name { + fieldOrder[i] = i + continue + } + // Find the matching field. + for j, otherName := range pending.Fields { + if name == otherName { + fieldOrder[i] = j + break + } + } + if fieldOrder[i] == -1 { + // The type name matches but we are lacking some common fields. + Failf("type %q has mismatched fields: %v (decode) and %v (encode)", + name, fields, pending.Fields) + } + } + // The type has been reeconciled. + rte.FieldOrder = fieldOrder + tbd.byID[id-1] = rte + return rte +} + +// interfaceType defines all interfaces. +const interfaceType = "interface" + +// primitiveTypeDatabase is a set of fixed types. +var primitiveTypeDatabase = func() map[string]reflect.Type { + r := make(map[string]reflect.Type) + for _, t := range []reflect.Type{ + reflect.TypeOf(false), + reflect.TypeOf(int(0)), + reflect.TypeOf(int8(0)), + reflect.TypeOf(int16(0)), + reflect.TypeOf(int32(0)), + reflect.TypeOf(int64(0)), + reflect.TypeOf(uint(0)), + reflect.TypeOf(uintptr(0)), + reflect.TypeOf(uint8(0)), + reflect.TypeOf(uint16(0)), + reflect.TypeOf(uint32(0)), + reflect.TypeOf(uint64(0)), + reflect.TypeOf(""), + reflect.TypeOf(float32(0.0)), + reflect.TypeOf(float64(0.0)), + reflect.TypeOf(complex64(0.0)), + reflect.TypeOf(complex128(0.0)), + } { + r[t.Name()] = t + } + return r +}() + +// globalTypeDatabase is used for dispatching interfaces on decode. +var globalTypeDatabase = map[string]reflect.Type{} + +// Register registers a type. +// +// This must be called on init and only done once. +func Register(t Type) { + name := t.StateTypeName() + fields := t.StateFields() + assertValidType(name, fields) + // Register must always be called on pointers. + typ := reflect.TypeOf(t) + if typ.Kind() != reflect.Ptr { + Failf("Register must be called on pointers") + } + typ = typ.Elem() + if typ.Kind() == reflect.Struct { + // All registered structs must implement SaverLoader. We allow + // the registration is non-struct types with just the Type + // interface, but we need to call StateSave/StateLoad methods + // on aggregate types. + if _, ok := t.(SaverLoader); !ok { + Failf("struct %T does not implement SaverLoader", t) + } + } else { + // Non-structs must not have any fields. We don't support + // calling StateSave/StateLoad methods on any non-struct types. + // If custom behavior is required, these types should be + // wrapped in a structure of some kind. + if len(fields) != 0 { + Failf("non-struct %T has non-zero fields %v", t, fields) + } + // We don't allow non-structs to implement StateSave/StateLoad + // methods, because they won't be called and it's confusing. + if _, ok := t.(SaverLoader); ok { + Failf("non-struct %T implements SaverLoader", t) + } + } + if _, ok := primitiveTypeDatabase[name]; ok { + Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t) + } + if _, ok := globalTypeDatabase[name]; ok { + Failf("conflicting globalTypeDatabase entries for %T: name conflict", t) + } + if name == interfaceType { + Failf("conflicting name for %T: matches interfaceType", t) + } + globalTypeDatabase[name] = typ +} diff --git a/pkg/state/wire/BUILD b/pkg/state/wire/BUILD new file mode 100644 index 000000000..311b93dcb --- /dev/null +++ b/pkg/state/wire/BUILD @@ -0,0 +1,12 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "wire", + srcs = ["wire.go"], + marshal = False, + stateify = False, + visibility = ["//:sandbox"], + deps = ["//pkg/gohacks"], +) diff --git a/pkg/state/wire/wire.go b/pkg/state/wire/wire.go new file mode 100644 index 000000000..93dee6740 --- /dev/null +++ b/pkg/state/wire/wire.go @@ -0,0 +1,970 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package wire contains a few basic types that can be composed to serialize +// graph information for the state package. This package defines the wire +// protocol. +// +// Note that these types are careful about how they implement the relevant +// interfaces (either value receiver or pointer receiver), so that native-sized +// types, such as integers and simple pointers, can fit inside the interface +// object. +// +// This package also uses panic as control flow, so called should be careful to +// wrap calls in appropriate handlers. +// +// Testing for this package is driven by the state test package. +package wire + +import ( + "fmt" + "io" + "math" + + "gvisor.dev/gvisor/pkg/gohacks" +) + +// Reader is the required reader interface. +type Reader interface { + io.Reader + ReadByte() (byte, error) +} + +// Writer is the required writer interface. +type Writer interface { + io.Writer + WriteByte(byte) error +} + +// readFull is a utility. The equivalent is not needed for Write, but the API +// contract dictates that it must always complete all bytes given or return an +// error. +func readFull(r io.Reader, p []byte) { + for done := 0; done < len(p); { + n, err := r.Read(p[done:]) + done += n + if n == 0 && err != nil { + panic(err) + } + } +} + +// Object is a generic object. +type Object interface { + // save saves the given object. + // + // Panic is used for error control flow. + save(Writer) + + // load loads a new object of the given type. + // + // Panic is used for error control flow. + load(Reader) Object +} + +// Bool is a boolean. +type Bool bool + +// loadBool loads an object of type Bool. +func loadBool(r Reader) Bool { + b := loadUint(r) + return Bool(b == 1) +} + +// save implements Object.save. +func (b Bool) save(w Writer) { + var v Uint + if b { + v = 1 + } else { + v = 0 + } + v.save(w) +} + +// load implements Object.load. +func (Bool) load(r Reader) Object { return loadBool(r) } + +// Int is a signed integer. +// +// This uses varint encoding. +type Int int64 + +// loadInt loads an object of type Int. +func loadInt(r Reader) Int { + u := loadUint(r) + x := Int(u >> 1) + if u&1 != 0 { + x = ^x + } + return x +} + +// save implements Object.save. +func (i Int) save(w Writer) { + u := Uint(i) << 1 + if i < 0 { + u = ^u + } + u.save(w) +} + +// load implements Object.load. +func (Int) load(r Reader) Object { return loadInt(r) } + +// Uint is an unsigned integer. +type Uint uint64 + +// loadUint loads an object of type Uint. +func loadUint(r Reader) Uint { + var ( + u Uint + s uint + ) + for i := 0; i <= 9; i++ { + b, err := r.ReadByte() + if err != nil { + panic(err) + } + if b < 0x80 { + if i == 9 && b > 1 { + panic("overflow") + } + u |= Uint(b) << s + return u + } + u |= Uint(b&0x7f) << s + s += 7 + } + panic("unreachable") +} + +// save implements Object.save. +func (u Uint) save(w Writer) { + for u >= 0x80 { + if err := w.WriteByte(byte(u) | 0x80); err != nil { + panic(err) + } + u >>= 7 + } + if err := w.WriteByte(byte(u)); err != nil { + panic(err) + } +} + +// load implements Object.load. +func (Uint) load(r Reader) Object { return loadUint(r) } + +// Float32 is a 32-bit floating point number. +type Float32 float32 + +// loadFloat32 loads an object of type Float32. +func loadFloat32(r Reader) Float32 { + n := loadUint(r) + return Float32(math.Float32frombits(uint32(n))) +} + +// save implements Object.save. +func (f Float32) save(w Writer) { + n := Uint(math.Float32bits(float32(f))) + n.save(w) +} + +// load implements Object.load. +func (Float32) load(r Reader) Object { return loadFloat32(r) } + +// Float64 is a 64-bit floating point number. +type Float64 float64 + +// loadFloat64 loads an object of type Float64. +func loadFloat64(r Reader) Float64 { + n := loadUint(r) + return Float64(math.Float64frombits(uint64(n))) +} + +// save implements Object.save. +func (f Float64) save(w Writer) { + n := Uint(math.Float64bits(float64(f))) + n.save(w) +} + +// load implements Object.load. +func (Float64) load(r Reader) Object { return loadFloat64(r) } + +// Complex64 is a 64-bit complex number. +type Complex64 complex128 + +// loadComplex64 loads an object of type Complex64. +func loadComplex64(r Reader) Complex64 { + re := loadFloat32(r) + im := loadFloat32(r) + return Complex64(complex(float32(re), float32(im))) +} + +// save implements Object.save. +func (c *Complex64) save(w Writer) { + re := Float32(real(*c)) + im := Float32(imag(*c)) + re.save(w) + im.save(w) +} + +// load implements Object.load. +func (*Complex64) load(r Reader) Object { + c := loadComplex64(r) + return &c +} + +// Complex128 is a 128-bit complex number. +type Complex128 complex128 + +// loadComplex128 loads an object of type Complex128. +func loadComplex128(r Reader) Complex128 { + re := loadFloat64(r) + im := loadFloat64(r) + return Complex128(complex(float64(re), float64(im))) +} + +// save implements Object.save. +func (c *Complex128) save(w Writer) { + re := Float64(real(*c)) + im := Float64(imag(*c)) + re.save(w) + im.save(w) +} + +// load implements Object.load. +func (*Complex128) load(r Reader) Object { + c := loadComplex128(r) + return &c +} + +// String is a string. +type String string + +// loadString loads an object of type String. +func loadString(r Reader) String { + l := loadUint(r) + p := make([]byte, l) + readFull(r, p) + return String(gohacks.StringFromImmutableBytes(p)) +} + +// save implements Object.save. +func (s *String) save(w Writer) { + l := Uint(len(*s)) + l.save(w) + p := gohacks.ImmutableBytesFromString(string(*s)) + _, err := w.Write(p) // Must write all bytes. + if err != nil { + panic(err) + } +} + +// load implements Object.load. +func (*String) load(r Reader) Object { + s := loadString(r) + return &s +} + +// Dot is a kind of reference: one of Index and FieldName. +type Dot interface { + isDot() +} + +// Index is a reference resolution. +type Index uint32 + +func (Index) isDot() {} + +// FieldName is a reference resolution. +type FieldName string + +func (*FieldName) isDot() {} + +// Ref is a reference to an object. +type Ref struct { + // Root is the root object. + Root Uint + + // Dots is the set of traversals required from the Root object above. + // Note that this will be stored in reverse order for efficiency. + Dots []Dot + + // Type is the base type for the root object. This is non-nil iff Dots + // is non-zero length (that is, this is a complex reference). This is + // not *strictly* necessary, but can be used to simplify decoding. + Type TypeSpec +} + +// loadRef loads an object of type Ref (abstract). +func loadRef(r Reader) Ref { + ref := Ref{ + Root: loadUint(r), + } + l := loadUint(r) + ref.Dots = make([]Dot, l) + for i := 0; i < int(l); i++ { + // Disambiguate between an Index (non-negative) and a field + // name (negative). This does some space and avoids a dedicate + // loadDot function. See Ref.save for the other side. + d := loadInt(r) + if d >= 0 { + ref.Dots[i] = Index(d) + continue + } + p := make([]byte, -d) + readFull(r, p) + fieldName := FieldName(gohacks.StringFromImmutableBytes(p)) + ref.Dots[i] = &fieldName + } + if l != 0 { + // Only if dots is non-zero. + ref.Type = loadTypeSpec(r) + } + return ref +} + +// save implements Object.save. +func (r *Ref) save(w Writer) { + r.Root.save(w) + l := Uint(len(r.Dots)) + l.save(w) + for _, d := range r.Dots { + // See LoadRef. We use non-negative numbers to encode Index + // objects and negative numbers to encode field lengths. + switch x := d.(type) { + case Index: + i := Int(x) + i.save(w) + case *FieldName: + d := Int(-len(*x)) + d.save(w) + p := gohacks.ImmutableBytesFromString(string(*x)) + if _, err := w.Write(p); err != nil { + panic(err) + } + default: + panic("unknown dot implementation") + } + } + if l != 0 { + // See above. + saveTypeSpec(w, r.Type) + } +} + +// load implements Object.load. +func (*Ref) load(r Reader) Object { + ref := loadRef(r) + return &ref +} + +// Nil is a primitive zero value of any type. +type Nil struct{} + +// loadNil loads an object of type Nil. +func loadNil(r Reader) Nil { + return Nil{} +} + +// save implements Object.save. +func (Nil) save(w Writer) {} + +// load implements Object.load. +func (Nil) load(r Reader) Object { return loadNil(r) } + +// Slice is a slice value. +type Slice struct { + Length Uint + Capacity Uint + Ref Ref +} + +// loadSlice loads an object of type Slice. +func loadSlice(r Reader) Slice { + return Slice{ + Length: loadUint(r), + Capacity: loadUint(r), + Ref: loadRef(r), + } +} + +// save implements Object.save. +func (s *Slice) save(w Writer) { + s.Length.save(w) + s.Capacity.save(w) + s.Ref.save(w) +} + +// load implements Object.load. +func (*Slice) load(r Reader) Object { + s := loadSlice(r) + return &s +} + +// Array is an array value. +type Array struct { + Contents []Object +} + +// loadArray loads an object of type Array. +func loadArray(r Reader) Array { + l := loadUint(r) + if l == 0 { + // Note that there isn't a single object available to encode + // the type of, so we need this additional branch. + return Array{} + } + // All the objects here have the same type, so use dynamic dispatch + // only once. All other objects will automatically take the same type + // as the first object. + contents := make([]Object, l) + v := Load(r) + contents[0] = v + for i := 1; i < int(l); i++ { + contents[i] = v.load(r) + } + return Array{ + Contents: contents, + } +} + +// save implements Object.save. +func (a *Array) save(w Writer) { + l := Uint(len(a.Contents)) + l.save(w) + if l == 0 { + // See LoadArray. + return + } + // See above. + Save(w, a.Contents[0]) + for i := 1; i < int(l); i++ { + a.Contents[i].save(w) + } +} + +// load implements Object.load. +func (*Array) load(r Reader) Object { + a := loadArray(r) + return &a +} + +// Map is a map value. +type Map struct { + Keys []Object + Values []Object +} + +// loadMap loads an object of type Map. +func loadMap(r Reader) Map { + l := loadUint(r) + if l == 0 { + // See LoadArray. + return Map{} + } + // See type dispatch notes in Array. + keys := make([]Object, l) + values := make([]Object, l) + k := Load(r) + v := Load(r) + keys[0] = k + values[0] = v + for i := 1; i < int(l); i++ { + keys[i] = k.load(r) + values[i] = v.load(r) + } + return Map{ + Keys: keys, + Values: values, + } +} + +// save implements Object.save. +func (m *Map) save(w Writer) { + l := Uint(len(m.Keys)) + if int(l) != len(m.Values) { + panic(fmt.Sprintf("mismatched keys (%d) Aand values (%d)", len(m.Keys), len(m.Values))) + } + l.save(w) + if l == 0 { + // See LoadArray. + return + } + // See above. + Save(w, m.Keys[0]) + Save(w, m.Values[0]) + for i := 1; i < int(l); i++ { + m.Keys[i].save(w) + m.Values[i].save(w) + } +} + +// load implements Object.load. +func (*Map) load(r Reader) Object { + m := loadMap(r) + return &m +} + +// TypeSpec is a type dereference. +type TypeSpec interface { + isTypeSpec() +} + +// TypeID is a concrete type ID. +type TypeID Uint + +func (TypeID) isTypeSpec() {} + +// TypeSpecPointer is a pointer type. +type TypeSpecPointer struct { + Type TypeSpec +} + +func (*TypeSpecPointer) isTypeSpec() {} + +// TypeSpecArray is an array type. +type TypeSpecArray struct { + Count Uint + Type TypeSpec +} + +func (*TypeSpecArray) isTypeSpec() {} + +// TypeSpecSlice is a slice type. +type TypeSpecSlice struct { + Type TypeSpec +} + +func (*TypeSpecSlice) isTypeSpec() {} + +// TypeSpecMap is a map type. +type TypeSpecMap struct { + Key TypeSpec + Value TypeSpec +} + +func (*TypeSpecMap) isTypeSpec() {} + +// TypeSpecNil is an empty type. +type TypeSpecNil struct{} + +func (TypeSpecNil) isTypeSpec() {} + +// TypeSpec types. +// +// These use a distinct encoding on the wire, as they are used only in the +// interface object. They are decoded through the dedicated loadTypeSpec and +// saveTypeSpec functions. +const ( + typeSpecTypeID Uint = iota + typeSpecPointer + typeSpecArray + typeSpecSlice + typeSpecMap + typeSpecNil +) + +// loadTypeSpec loads TypeSpec values. +func loadTypeSpec(r Reader) TypeSpec { + switch hdr := loadUint(r); hdr { + case typeSpecTypeID: + return TypeID(loadUint(r)) + case typeSpecPointer: + return &TypeSpecPointer{ + Type: loadTypeSpec(r), + } + case typeSpecArray: + return &TypeSpecArray{ + Count: loadUint(r), + Type: loadTypeSpec(r), + } + case typeSpecSlice: + return &TypeSpecSlice{ + Type: loadTypeSpec(r), + } + case typeSpecMap: + return &TypeSpecMap{ + Key: loadTypeSpec(r), + Value: loadTypeSpec(r), + } + case typeSpecNil: + return TypeSpecNil{} + default: + // This is not a valid stream? + panic(fmt.Errorf("unknown header: %d", hdr)) + } +} + +// saveTypeSpec saves TypeSpec values. +func saveTypeSpec(w Writer, t TypeSpec) { + switch x := t.(type) { + case TypeID: + typeSpecTypeID.save(w) + Uint(x).save(w) + case *TypeSpecPointer: + typeSpecPointer.save(w) + saveTypeSpec(w, x.Type) + case *TypeSpecArray: + typeSpecArray.save(w) + x.Count.save(w) + saveTypeSpec(w, x.Type) + case *TypeSpecSlice: + typeSpecSlice.save(w) + saveTypeSpec(w, x.Type) + case *TypeSpecMap: + typeSpecMap.save(w) + saveTypeSpec(w, x.Key) + saveTypeSpec(w, x.Value) + case TypeSpecNil: + typeSpecNil.save(w) + default: + // This should not happen? + panic(fmt.Errorf("unknown type %T", t)) + } +} + +// Interface is an interface value. +type Interface struct { + Type TypeSpec + Value Object +} + +// loadInterface loads an object of type Interface. +func loadInterface(r Reader) Interface { + return Interface{ + Type: loadTypeSpec(r), + Value: Load(r), + } +} + +// save implements Object.save. +func (i *Interface) save(w Writer) { + saveTypeSpec(w, i.Type) + Save(w, i.Value) +} + +// load implements Object.load. +func (*Interface) load(r Reader) Object { + i := loadInterface(r) + return &i +} + +// Type is type information. +type Type struct { + Name string + Fields []string +} + +// loadType loads an object of type Type. +func loadType(r Reader) Type { + name := string(loadString(r)) + l := loadUint(r) + fields := make([]string, l) + for i := 0; i < int(l); i++ { + fields[i] = string(loadString(r)) + } + return Type{ + Name: name, + Fields: fields, + } +} + +// save implements Object.save. +func (t *Type) save(w Writer) { + s := String(t.Name) + s.save(w) + l := Uint(len(t.Fields)) + l.save(w) + for i := 0; i < int(l); i++ { + s := String(t.Fields[i]) + s.save(w) + } +} + +// load implements Object.load. +func (*Type) load(r Reader) Object { + t := loadType(r) + return &t +} + +// multipleObjects is a special type for serializing multiple objects. +type multipleObjects []Object + +// loadMultipleObjects loads a series of objects. +func loadMultipleObjects(r Reader) multipleObjects { + l := loadUint(r) + m := make(multipleObjects, l) + for i := 0; i < int(l); i++ { + m[i] = Load(r) + } + return m +} + +// save implements Object.save. +func (m *multipleObjects) save(w Writer) { + l := Uint(len(*m)) + l.save(w) + for i := 0; i < int(l); i++ { + Save(w, (*m)[i]) + } +} + +// load implements Object.load. +func (*multipleObjects) load(r Reader) Object { + m := loadMultipleObjects(r) + return &m +} + +// noObjects represents no objects. +type noObjects struct{} + +// loadNoObjects loads a sentinel. +func loadNoObjects(r Reader) noObjects { return noObjects{} } + +// save implements Object.save. +func (noObjects) save(w Writer) {} + +// load implements Object.load. +func (noObjects) load(r Reader) Object { return loadNoObjects(r) } + +// Struct is a basic composite value. +type Struct struct { + TypeID TypeID + fields Object // Optionally noObjects or *multipleObjects. +} + +// Field returns a pointer to the given field slot. +// +// This must be called after Alloc. +func (s *Struct) Field(i int) *Object { + if fields, ok := s.fields.(*multipleObjects); ok { + return &((*fields)[i]) + } + if _, ok := s.fields.(noObjects); ok { + // Alloc may be optionally called; can't call twice. + panic("Field called inappropriately, wrong Alloc?") + } + return &s.fields +} + +// Alloc allocates the given number of fields. +// +// This must be called before Add and Save. +// +// Precondition: slots must be positive. +func (s *Struct) Alloc(slots int) { + switch { + case slots == 0: + s.fields = noObjects{} + case slots == 1: + // Leave it alone. + case slots > 1: + fields := make(multipleObjects, slots) + s.fields = &fields + default: + // Violates precondition. + panic(fmt.Sprintf("Alloc called with negative slots %d?", slots)) + } +} + +// Fields returns the number of fields. +func (s *Struct) Fields() int { + switch x := s.fields.(type) { + case *multipleObjects: + return len(*x) + case noObjects: + return 0 + default: + return 1 + } +} + +// loadStruct loads an object of type Struct. +func loadStruct(r Reader) Struct { + return Struct{ + TypeID: TypeID(loadUint(r)), + fields: Load(r), + } +} + +// save implements Object.save. +// +// Precondition: Alloc must have been called, and the fields all filled in +// appropriately. See Alloc and Add for more details. +func (s *Struct) save(w Writer) { + Uint(s.TypeID).save(w) + Save(w, s.fields) +} + +// load implements Object.load. +func (*Struct) load(r Reader) Object { + s := loadStruct(r) + return &s +} + +// Object types. +// +// N.B. Be careful about changing the order or introducing new elements in the +// middle here. This is part of the wire format and shouldn't change. +const ( + typeBool Uint = iota + typeInt + typeUint + typeFloat32 + typeFloat64 + typeNil + typeRef + typeString + typeSlice + typeArray + typeMap + typeStruct + typeNoObjects + typeMultipleObjects + typeInterface + typeComplex64 + typeComplex128 + typeType +) + +// Save saves the given object. +// +// +checkescape all +// +// N.B. This function will panic on error. +func Save(w Writer, obj Object) { + switch x := obj.(type) { + case Bool: + typeBool.save(w) + x.save(w) + case Int: + typeInt.save(w) + x.save(w) + case Uint: + typeUint.save(w) + x.save(w) + case Float32: + typeFloat32.save(w) + x.save(w) + case Float64: + typeFloat64.save(w) + x.save(w) + case Nil: + typeNil.save(w) + x.save(w) + case *Ref: + typeRef.save(w) + x.save(w) + case *String: + typeString.save(w) + x.save(w) + case *Slice: + typeSlice.save(w) + x.save(w) + case *Array: + typeArray.save(w) + x.save(w) + case *Map: + typeMap.save(w) + x.save(w) + case *Struct: + typeStruct.save(w) + x.save(w) + case noObjects: + typeNoObjects.save(w) + x.save(w) + case *multipleObjects: + typeMultipleObjects.save(w) + x.save(w) + case *Interface: + typeInterface.save(w) + x.save(w) + case *Type: + typeType.save(w) + x.save(w) + case *Complex64: + typeComplex64.save(w) + x.save(w) + case *Complex128: + typeComplex128.save(w) + x.save(w) + default: + panic(fmt.Errorf("unknown type: %#v", obj)) + } +} + +// Load loads a new object. +// +// +checkescape all +// +// N.B. This function will panic on error. +func Load(r Reader) Object { + switch hdr := loadUint(r); hdr { + case typeBool: + return loadBool(r) + case typeInt: + return loadInt(r) + case typeUint: + return loadUint(r) + case typeFloat32: + return loadFloat32(r) + case typeFloat64: + return loadFloat64(r) + case typeNil: + return loadNil(r) + case typeRef: + return ((*Ref)(nil)).load(r) // Escapes. + case typeString: + return ((*String)(nil)).load(r) // Escapes. + case typeSlice: + return ((*Slice)(nil)).load(r) // Escapes. + case typeArray: + return ((*Array)(nil)).load(r) // Escapes. + case typeMap: + return ((*Map)(nil)).load(r) // Escapes. + case typeStruct: + return ((*Struct)(nil)).load(r) // Escapes. + case typeNoObjects: // Special for struct. + return loadNoObjects(r) + case typeMultipleObjects: // Special for struct. + return ((*multipleObjects)(nil)).load(r) // Escapes. + case typeInterface: + return ((*Interface)(nil)).load(r) // Escapes. + case typeComplex64: + return ((*Complex64)(nil)).load(r) // Escapes. + case typeComplex128: + return ((*Complex128)(nil)).load(r) // Escapes. + case typeType: + return ((*Type)(nil)).load(r) // Escapes. + default: + // This is not a valid stream? + panic(fmt.Errorf("unknown header: %d", hdr)) + } +} + +// LoadUint loads a single unsigned integer. +// +// N.B. This function will panic on error. +func LoadUint(r Reader) uint64 { + return uint64(loadUint(r)) +} + +// SaveUint saves a single unsigned integer. +// +// N.B. This function will panic on error. +func SaveUint(w Writer, v uint64) { + Uint(v).save(w) +} |