summaryrefslogtreecommitdiffhomepage
path: root/pkg/state
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/state')
-rw-r--r-- pkg/state/BUILD 77
-rw-r--r-- pkg/state/decode.go 594
-rw-r--r-- pkg/state/encode.go 454
-rw-r--r-- pkg/state/encode_unsafe.go 81
-rw-r--r-- pkg/state/map.go 221
-rw-r--r-- pkg/state/object.proto 140
-rw-r--r-- pkg/state/printer.go 188
-rw-r--r-- pkg/state/state.go 349
-rw-r--r-- pkg/state/state_test.go 719
-rw-r--r-- pkg/state/statefile/BUILD 23
-rw-r--r-- pkg/state/statefile/statefile.go 233
-rw-r--r-- pkg/state/statefile/statefile_test.go 299
-rw-r--r-- pkg/state/stats.go 133
13 files changed, 3511 insertions, 0 deletions
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
new file mode 100644
index 000000000..bb6415d9b
--- /dev/null
+++ b/pkg/state/BUILD
@@ -0,0 +1,77 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+# addr_range generates addr_range.go in package "state" by instantiating the
+# //pkg/segment:generic_range template with T bound to uintptr and the "addr"
+# prefix.
+go_template_instance(
+ name = "addr_range",
+ out = "addr_range.go",
+ package = "state",
+ prefix = "addr",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "uintptr",
+ },
+)
+
+# addr_set generates addr_set.go in package "state" by instantiating the
+# //pkg/segment:generic_set template with the "addr" prefix: keys are uintptr,
+# ranges are addrRange, values are reflect.Value (hence the reflect import),
+# and the minDegree constant is set to 10.
+go_template_instance(
+ name = "addr_set",
+ out = "addr_set.go",
+ consts = {
+ "minDegree": "10",
+ },
+ imports = {
+ "reflect": "reflect",
+ },
+ package = "state",
+ prefix = "addr",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uintptr",
+ "Range": "addrRange",
+ "Value": "reflect.Value",
+ "Functions": "addrSetFunctions",
+ },
+)
+
+# state is the Go library for this package. Its sources include the generated
+# addr_range.go and addr_set.go, and it depends on the generated object proto
+# bindings and the golang protobuf runtime.
+go_library(
+ name = "state",
+ srcs = [
+ "addr_range.go",
+ "addr_set.go",
+ "decode.go",
+ "encode.go",
+ "encode_unsafe.go",
+ "map.go",
+ "printer.go",
+ "state.go",
+ "stats.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/state",
+ visibility = ["//:sandbox"],
+ deps = [
+ ":object_go_proto",
+ "@com_github_golang_protobuf//proto:go_default_library",
+ ],
+)
+
+# object_proto wraps object.proto, the wire schema for serialized objects.
+proto_library(
+ name = "object_proto",
+ srcs = ["object.proto"],
+ visibility = ["//:sandbox"],
+)
+
+# object_go_proto generates the Go bindings for :object_proto; the Go sources
+# in this package import them under the importpath below.
+go_proto_library(
+ name = "object_go_proto",
+ importpath = "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto",
+ proto = ":object_proto",
+ visibility = ["//:sandbox"],
+)
+
+# state_test embeds :state, so state_test.go is compiled together with the
+# library sources.
+go_test(
+ name = "state_test",
+ size = "small",
+ srcs = ["state_test.go"],
+ embed = [":state"],
+)
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
new file mode 100644
index 000000000..05758495b
--- /dev/null
+++ b/pkg/state/decode.go
@@ -0,0 +1,594 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "reflect"
+ "sort"
+
+ "github.com/golang/protobuf/proto"
+ pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto"
+)
+
+// objectState represents an object that may be in the process of being
+// decoded. Specifically, it represents either a decoded object, or an
+// interest in a future object that will be decoded. When that interest is
+// registered (via register), the storage for the object will be created, but
+// it will not be decoded until the object is encountered in the stream.
+type objectState struct {
+ // id is the id for this object.
+ //
+ // If this field is zero, then this is an anonymous (unregistered,
+ // non-reference primitive) object. This is immutable.
+ id uint64
+
+ // obj is the object. This may or may not be valid yet, depending on
+ // whether complete returns true. However, regardless of whether the
+ // object is valid, obj contains a final storage location for the
+ // object. This is immutable.
+ //
+ // Note that this must be addressable (obj.Addr() must not panic).
+ //
+ // The obj passed to the decode methods below will equal this obj only
+ // in the case of decoding the top-level object. However, the passed
+ // obj may represent individual fields, elements of a slice, etc. that
+ // are effectively embedded within the reflect.Value below but with
+ // distinct types.
+ obj reflect.Value
+
+ // blockedBy is the number of dependencies this object is still
+ // waiting on. It is incremented by waitFor and decremented by
+ // checkComplete on the object being waited for.
+ blockedBy int
+
+ // blocking is a list of the objects blocked by this one.
+ blocking []*objectState
+
+ // callbacks is a set of callbacks to execute on load. They are run
+ // (and then cleared) by checkComplete once blockedBy reaches zero.
+ callbacks []func()
+
+ // path is the decoding path to the object.
+ path recoverable
+}
+
+// complete reports whether the object has finished loading: it is not
+// waiting on any other object and has no callbacks left to run.
+func (os *objectState) complete() bool {
+ if os.blockedBy != 0 {
+ return false
+ }
+ return len(os.callbacks) == 0
+}
+
+// checkComplete checks for completion. If the object is complete, pending
+// callbacks will be executed and checkComplete will be called on downstream
+// objects (those depending on this one).
+//
+// If the object is still blocked (blockedBy > 0) this is a no-op. Each
+// callback invocation is bracketed by stats.Start/stats.Done for this object.
+func (os *objectState) checkComplete(stats *Stats) {
+ if os.blockedBy > 0 {
+ return
+ }
+
+ // Fire all callbacks.
+ for _, fn := range os.callbacks {
+ stats.Start(os.obj)
+ fn()
+ stats.Done()
+ }
+ // Drop the callbacks so they run only once and complete() can succeed.
+ os.callbacks = nil
+
+ // Clear all blocked objects: each one has one fewer dependency, and
+ // may now be complete itself.
+ for _, other := range os.blocking {
+ other.blockedBy--
+ other.checkComplete(stats)
+ }
+ os.blocking = nil
+}
+
+// waitFor records that os depends on other.
+//
+// os gains one more blocker and is added to other's blocking list. A non-nil
+// callback is queued on other (not on os), so it runs when other's callbacks
+// are fired.
+func (os *objectState) waitFor(other *objectState, callback func()) {
+ os.blockedBy++
+ if callback != nil {
+ other.callbacks = append(other.callbacks, callback)
+ }
+ other.blocking = append(other.blocking, os)
+}
+
+// findCycleFor searches the objects transitively blocked by os for target.
+//
+// If target is reachable, it returns the chain of objects leading to it,
+// starting with target and ending with the direct dependent of os through
+// which it was reached. It returns nil if target is not reachable.
+func (os *objectState) findCycleFor(target *objectState) []*objectState {
+ return os.findCycleVisiting(target, make(map[*objectState]struct{}))
+}
+
+// findCycleVisiting is the depth-first worker for findCycleFor.
+//
+// visited holds every object whose blocking list has already been (or is
+// being) explored. Without it the search would recurse forever whenever the
+// blocking graph contains a cycle that does not pass through target, which
+// happens when the starting object is merely upstream of a cycle.
+func (os *objectState) findCycleVisiting(target *objectState, visited map[*objectState]struct{}) []*objectState {
+ for _, other := range os.blocking {
+ if other == target {
+ return []*objectState{target}
+ }
+ if _, seen := visited[other]; seen {
+ // Already explored; it cannot yield a new path to target.
+ continue
+ }
+ visited[other] = struct{}{}
+ if childList := other.findCycleVisiting(target, visited); childList != nil {
+ return append(childList, other)
+ }
+ }
+ return nil
+}
+
+// findCycle finds a dependency cycle through os.
+//
+// The result is whatever findCycleFor(os) returns with os appended; when no
+// cycle is found the result contains only os.
+func (os *objectState) findCycle() []*objectState {
+ cycle := os.findCycleFor(os)
+ cycle = append(cycle, os)
+ return cycle
+}
+
+// decodeState is a graph of objects in the process of being decoded.
+//
+// The decode process involves loading the breadth-first graph generated by
+// encode. This graph is read in its entirety, ensuring that all object
+// storage is complete.
+//
+// As the graph is being serialized, a set of completion callbacks are
+// executed. These completion callbacks should form a set of acyclic subgraphs
+// over the original one. After decoding is complete, the objects are scanned
+// to ensure that all callbacks are executed, otherwise the callback graph was
+// not acyclic.
+type decodeState struct {
+ // objectsByID is the set of objects in progress.
+ objectsByID map[uint64]*objectState
+
+ // deferred are objects that have been read, but no interest has been
+ // registered yet. These will be decoded once interest is registered.
+ deferred map[uint64]*pb.Object
+
+ // outstanding is the number of outstanding objects.
+ outstanding uint32
+
+ // r is the input stream.
+ r io.Reader
+
+ // stats is the passed stats object.
+ stats *Stats
+
+ // recoverable is the panic recover facility.
+ recoverable
+}
+
+// lookup returns the objectState registered under id, or nil when no object
+// with that id has been registered yet.
+func (ds *decodeState) lookup(id uint64) *objectState {
+ if os, ok := ds.objectsByID[id]; ok {
+ return os
+ }
+ return nil
+}
+
+// wait registers a dependency of waiter on the object with the given id.
+//
+// Three ids never block, and have their callback (if any) run immediately:
+//
+// - 0, the nil pointer, since there is nothing to wait for;
+// - waiter.id, a trivial self reference, which can be handled internally;
+// - 1, the root object: we always allow _useable_ references back to the
+// first decoding object because it may have fields that are already
+// decoded.
+//
+// For any other id the object must already be registered.
+func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
+ if id == 0 || id == waiter.id || id == 1 {
+ if callback != nil {
+ callback()
+ }
+ return
+ }
+
+ // No nil can be returned here.
+ target := ds.lookup(id)
+ waiter.waitFor(target, callback)
+}
+
+// waitObject notes a blocking relationship.
+func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) {
+ if rv, ok := p.Value.(*pb.Object_RefValue); ok {
+ // Refs can encode pointers and maps.
+ ds.wait(os, rv.RefValue, callback)
+ } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok {
+ // See decodeObject; we need to wait for the array (if non-nil).
+ ds.wait(os, sv.SliceValue.RefValue, callback)
+ } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok {
+ // It's an interface (wait recurisvely).
+ ds.waitObject(os, iv.InterfaceValue.Value, callback)
+ } else if callback != nil {
+ // Nothing to wait for: execute the callback immediately.
+ callback()
+ }
+}
+
+// register registers a decode with a type.
+//
+// This type is only used to instantiate a new object if it has not been
+// registered previously; if id is already known, the existing objectState is
+// returned unchanged.
+//
+// If the object was already read from the stream and deferred, it is decoded
+// immediately; otherwise the outstanding count is bumped.
+func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState {
+ os, ok := ds.objectsByID[id]
+ if ok {
+ return os
+ }
+
+ // Record in the object index. Maps are created with MakeMap; every
+ // other type gets addressable zero-valued storage.
+ if typ.Kind() == reflect.Map {
+ os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()}
+ } else {
+ os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()}
+ }
+ ds.objectsByID[id] = os
+
+ if o, ok := ds.deferred[id]; ok {
+ // There is a deferred object.
+ delete(ds.deferred, id) // Free memory.
+ ds.decodeObject(os, os.obj, o, "", nil)
+ } else {
+ // There is no deferred object.
+ ds.outstanding++
+ }
+
+ return os
+}
+
+// decodeStruct decodes a struct value.
+//
+// The encoded fields are collected into a Map, sorted by name, and handed to
+// the load function registered for obj's pointer type. An unregistered type
+// is only accepted if the struct has no fields; otherwise this panics with
+// an error.
+func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) {
+ // Set the fields.
+ m := Map{newInternalMap(nil, ds, os)}
+ defer internalMapPool.Put(m.internalMap)
+ for _, field := range s.Fields {
+ m.data = append(m.data, entry{
+ name: field.Name,
+ object: field.Value,
+ })
+ }
+
+ // Sort the fields for efficient searching.
+ //
+ // Technically, these should already appear in sorted order in the
+ // state ordering, so this cost is effectively a single scan to ensure
+ // that the order is correct.
+ if len(m.data) > 1 {
+ sort.Slice(m.data, func(i, j int) bool {
+ return m.data[i].name < m.data[j].name
+ })
+ }
+
+ // Invoke the load; this will recursively decode other objects.
+ fns, ok := registeredTypes.lookupFns(obj.Addr().Type())
+ if ok {
+ // Invoke the loader.
+ fns.invokeLoad(obj.Addr(), m)
+ } else if obj.NumField() == 0 {
+ // Allow anonymous empty structs.
+ return
+ } else {
+ // Propagate an error.
+ panic(fmt.Errorf("unregistered type %s", obj.Type()))
+ }
+}
+
+// decodeMap decodes a map value into obj, allocating the map if it is nil.
+//
+// Keys and values are paired by index, so a message whose key and value
+// counts differ is rejected up front with an error panic rather than failing
+// with an index-out-of-range runtime panic (or silently dropping values)
+// part-way through.
+func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) {
+ if len(m.Keys) != len(m.Values) {
+ panic(fmt.Errorf("mismatching map lengths keys=%d, values=%d", len(m.Keys), len(m.Values)))
+ }
+ if obj.IsNil() {
+ obj.Set(reflect.MakeMap(obj.Type()))
+ }
+ for i := 0; i < len(m.Keys); i++ {
+ // Decode the objects into fresh, addressable key/value storage.
+ kv := reflect.New(obj.Type().Key()).Elem()
+ vv := reflect.New(obj.Type().Elem()).Elem()
+ ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i)
+ ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface())
+ ds.waitObject(os, m.Keys[i], nil)
+ ds.waitObject(os, m.Values[i], nil)
+
+ // Set in the map.
+ obj.SetMapIndex(kv, vv)
+ }
+}
+
+// decodeArray decodes an array value element by element into obj.
+//
+// The encoded array must have exactly obj.Len() elements; otherwise this
+// panics with an error.
+func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) {
+ if want, got := obj.Len(), len(a.Contents); got != want {
+ panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", want, got))
+ }
+ // Decode the contents into the array.
+ for i, elem := range a.Contents {
+ ds.decodeObject(os, obj.Index(i), elem, "[%d]", i)
+ ds.waitObject(os, elem, nil)
+ }
+}
+
+// decodeInterface decodes an interface value.
+//
+// An empty type name denotes a nil interface and leaves obj untouched. Any
+// other name must resolve through registeredTypes; an unknown name panics
+// with an error.
+func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) {
+ // Is this a nil value?
+ if i.Type == "" {
+ return // Just leave obj alone.
+ }
+
+ // Get the dispatchable type. This may not be used if the given
+ // reference has already been resolved, but if not we need to know the
+ // type to create.
+ t, ok := registeredTypes.lookupType(i.Type)
+ if !ok {
+ panic(fmt.Errorf("no valid type for %q", i.Type))
+ }
+
+ if obj.Kind() != reflect.Map {
+ // Set the obj to be the given typed value; this actually sets
+ // obj to be a non-zero value -- namely, it inserts type
+ // information. There's no need to do this for maps.
+ obj.Set(reflect.Zero(t))
+ }
+
+ // Decode the dereferenced element; there is no need to wait here, as
+ // the interface object shares the current object state.
+ ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type)
+}
+
+// decodeObject decodes an object value into obj.
+//
+// os is the object state on whose behalf decoding happens; it is passed
+// through to the array, struct, map and interface decoders. format and param
+// describe this step of the decoding path and are handed to ds.push; the
+// entry is popped again before returning.
+//
+// Scalar values that do not survive the round trip into obj's type, length
+// mismatches and unknown object kinds all panic.
+func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) {
+ ds.push(false, format, param)
+ ds.stats.Start(obj)
+
+ switch x := object.GetValue().(type) {
+ case *pb.Object_BoolValue:
+ obj.SetBool(x.BoolValue)
+ case *pb.Object_StringValue:
+ obj.SetString(x.StringValue)
+ case *pb.Object_Int64Value:
+ obj.SetInt(x.Int64Value)
+ // Reading the value back detects truncation by a narrower type.
+ if obj.Int() != x.Int64Value {
+ panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type()))
+ }
+ case *pb.Object_Uint64Value:
+ obj.SetUint(x.Uint64Value)
+ if obj.Uint() != x.Uint64Value {
+ panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type()))
+ }
+ case *pb.Object_DoubleValue:
+ obj.SetFloat(x.DoubleValue)
+ if obj.Float() != x.DoubleValue {
+ panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type()))
+ }
+ case *pb.Object_RefValue:
+ // Resolve the pointer itself, even though the object may not
+ // be decoded yet. You need to use wait() in order to ensure
+ // that is the case. See wait above, and Map.Barrier.
+ if id := x.RefValue; id != 0 {
+ // Decoding the interface should have imparted type
+ // information, so from this point it's safe to resolve
+ // and use this dynamic information for actually
+ // creating the object in register.
+ //
+ // (For non-interfaces this is a no-op).
+ dyntyp := reflect.TypeOf(obj.Interface())
+ if dyntyp.Kind() == reflect.Map {
+ obj.Set(ds.register(id, dyntyp).obj)
+ } else if dyntyp.Kind() == reflect.Ptr {
+ ds.push(true /* dereference */, "", nil)
+ obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
+ ds.pop()
+ } else {
+ obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
+ }
+ } else {
+ // We leave obj alone here. That's because if obj
+ // represents an interface, it may have been imbued
+ // with type information in decodeInterface, and we
+ // don't want to destroy that information.
+ }
+ case *pb.Object_SliceValue:
+ // It's okay to slice the array here, since the contents will
+ // still be provided later on. These semantics are a bit
+ // strange but they are handled in the Map.Barrier properly.
+ //
+ // The special semantics of zero ref apply here too.
+ if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 {
+ v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem())
+ obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity)))
+ }
+ case *pb.Object_ArrayValue:
+ ds.decodeArray(os, obj, x.ArrayValue)
+ case *pb.Object_StructValue:
+ ds.decodeStruct(os, obj, x.StructValue)
+ case *pb.Object_MapValue:
+ ds.decodeMap(os, obj, x.MapValue)
+ case *pb.Object_InterfaceValue:
+ ds.decodeInterface(os, obj, x.InterfaceValue)
+ case *pb.Object_ByteArrayValue:
+ copyArray(obj, reflect.ValueOf(x.ByteArrayValue))
+ case *pb.Object_Uint16ArrayValue:
+ // 16-bit slices are serialized as 32-bit slices.
+ // See object.proto for details.
+ s := x.Uint16ArrayValue.Values
+ t := obj.Slice(0, obj.Len()).Interface().([]uint16)
+ if len(t) != len(s) {
+ panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ }
+ for i := range s {
+ t[i] = uint16(s[i])
+ }
+ case *pb.Object_Uint32ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values))
+ case *pb.Object_Uint64ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values))
+ case *pb.Object_UintptrArrayValue:
+ copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
+ case *pb.Object_Int8ArrayValue:
+ copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
+ case *pb.Object_Int16ArrayValue:
+ // 16-bit slices are serialized as 32-bit slices.
+ // See object.proto for details.
+ s := x.Int16ArrayValue.Values
+ t := obj.Slice(0, obj.Len()).Interface().([]int16)
+ if len(t) != len(s) {
+ panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ }
+ for i := range s {
+ t[i] = int16(s[i])
+ }
+ case *pb.Object_Int32ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values))
+ case *pb.Object_Int64ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values))
+ case *pb.Object_BoolArrayValue:
+ copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values))
+ case *pb.Object_Float64ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values))
+ case *pb.Object_Float32ArrayValue:
+ copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values))
+ default:
+ // Should not happen, not propagated as an error.
+ panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type()))
+ }
+
+ ds.stats.Done()
+ ds.pop()
+}
+
+// copyArray copies src into dest after casting src to dest's element type.
+//
+// Both values must have the same length; otherwise this panics with an
+// error.
+func copyArray(dest reflect.Value, src reflect.Value) {
+ if want, got := dest.Len(), src.Len(); want != got {
+ panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", want, got))
+ }
+ converted := castSlice(src, dest.Type().Elem())
+ reflect.Copy(dest, converted)
+}
+
+// Deserialize deserializes the object state into obj, which becomes the root
+// object (id 1).
+//
+// Objects are read from the stream in id order for as long as any registered
+// interest is outstanding; objects nobody has registered yet are parked in
+// ds.deferred. The stream must then end with a zero-length, non-object
+// header. Finally all pending callbacks are fired, and any object that is
+// still incomplete is reported as a dependency cycle.
+//
+// This function may panic and should be run in safely().
+func (ds *decodeState) Deserialize(obj reflect.Value) {
+ ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()}
+ ds.outstanding = 1 // The root object.
+
+ // Decode all objects in the stream.
+ //
+ // See above, we never process objects while we have no outstanding
+ // interests (other than the very first object).
+ for id := uint64(1); ds.outstanding > 0; id++ {
+ o, err := ds.readObject()
+ if err != nil {
+ panic(err)
+ }
+
+ os := ds.lookup(id)
+ if os != nil {
+ // Decode the object.
+ ds.from = &os.path
+ ds.decodeObject(os, os.obj, o, "", nil)
+ ds.outstanding--
+ } else {
+ // If an object hasn't had interest registered
+ // previously, we deferred decoding until interest is
+ // registered.
+ ds.deferred[id] = o
+ }
+ }
+
+ // Check the zero-length header at the end.
+ length, object, err := ReadHeader(ds.r)
+ if err != nil {
+ panic(err)
+ }
+ if length != 0 {
+ panic(fmt.Sprintf("expected zero-length terminal, got %d", length))
+ }
+ if object {
+ panic("expected non-object terminal")
+ }
+
+ // Check if we have any deferred objects.
+ if count := len(ds.deferred); count > 0 {
+ // Should not happen, not propagated as an error.
+ panic(fmt.Sprintf("still have %d deferred objects", count))
+ }
+
+ // Scan and fire all callbacks.
+ for _, os := range ds.objectsByID {
+ os.checkComplete(ds.stats)
+ }
+
+ // Check if we have any remaining dependency cycles.
+ for _, os := range ds.objectsByID {
+ if !os.complete() {
+ // This must be the result of a dependency cycle.
+ cycle := os.findCycle()
+ var buf bytes.Buffer
+ buf.WriteString("dependency cycle: {")
+ for i, cycleOS := range cycle {
+ if i > 0 {
+ buf.WriteString(" => ")
+ }
+ buf.WriteString(cycleOS.obj.Type().String())
+ }
+ buf.WriteString("}")
+ // Panic as an error; propagate to the caller.
+ panic(errors.New(buf.String()))
+ }
+ }
+}
+
+// byteReader wraps an io.Reader so that it can be read one byte at a time;
+// ReadHeader uses it to feed binary.ReadUvarint.
+type byteReader struct {
+ io.Reader
+}
+
+// ReadByte implements io.ByteReader.
+//
+// It issues a single one-byte read on the embedded Reader. A read that
+// yields neither data nor an error is reported as io.ErrUnexpectedEOF.
+func (br byteReader) ReadByte() (byte, error) {
+ var one [1]byte
+ n, err := br.Reader.Read(one[:])
+ if n > 0 {
+ return one[0], nil
+ }
+ if err != nil {
+ return 0, err
+ }
+ return 0, io.ErrUnexpectedEOF
+}
+
+// ReadHeader reads an object header.
+//
+// Each object written to the statefile is prefixed with a header. See
+// WriteHeader for more information; these functions are exported to allow
+// non-state writes to the file to play nice with debugging tools.
+//
+// The header is a single uvarint: its lowest bit is the object flag and the
+// remaining bits are the length.
+func ReadHeader(r io.Reader) (length uint64, object bool, err error) {
+ // Read the header.
+ raw, err := binary.ReadUvarint(byteReader{r})
+ if err != nil {
+ return raw, false, err
+ }
+
+ // Split the flag bit from the length.
+ return raw >> 1, raw&0x1 != 0, nil
+}
+
+// readObject reads one object from the stream: a header that must carry the
+// object flag, followed by exactly that many bytes of marshalled pb.Object.
+//
+// A stream that ends part-way through the payload is reported as
+// io.ErrUnexpectedEOF.
+func (ds *decodeState) readObject() (*pb.Object, error) {
+ // Read the header.
+ length, object, err := ReadHeader(ds.r)
+ if err != nil {
+ return nil, err
+ }
+ if !object {
+ return nil, fmt.Errorf("invalid object header")
+ }
+
+ // Read the object.
+ //
+ // NOTE(review): length comes straight from the stream and is not
+ // bounded here, so a corrupt header can request a very large buffer.
+ buf := make([]byte, length)
+ if _, err := io.ReadFull(ds.r, buf); err != nil {
+ return nil, err
+ }
+
+ // Unmarshal.
+ obj := new(pb.Object)
+ if err := proto.Unmarshal(buf, obj); err != nil {
+ return nil, err
+ }
+
+ return obj, nil
+}
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
new file mode 100644
index 000000000..eb6527afc
--- /dev/null
+++ b/pkg/state/encode.go
@@ -0,0 +1,454 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "container/list"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "reflect"
+ "sort"
+
+ "github.com/golang/protobuf/proto"
+ pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto"
+)
+
+// queuedObject is an object queued for encoding.
+type queuedObject struct {
+ // id is the ID assigned to the object by encodeState.register.
+ id uint64
+ // obj is the value to encode: the pointed-to element for pointers, or
+ // the map itself for maps.
+ obj reflect.Value
+ // path is a copy of the encoder's path taken when the object was
+ // queued.
+ path recoverable
+}
+
+// encodeState is state used for encoding.
+//
+// The encoding process is a breadth-first traversal of the object graph. The
+// inherent races and dependencies are much simpler than the decode case.
+type encodeState struct {
+ // lastID is the last object ID.
+ //
+ // See idsByObject for context. Because of the special zero encoding
+ // used for reference values, the first ID must be 1.
+ lastID uint64
+
+ // idsByObject is a set of objects, indexed via:
+ //
+ // reflect.ValueOf(x).UnsafeAddr
+ //
+ // This provides IDs for objects.
+ idsByObject map[uintptr]uint64
+
+ // values stores values that span the addresses.
+ //
+ // addrSet is a generated type which efficiently stores ranges of
+ // addresses. When encoding pointers, these ranges are filled in and
+ // used to check for overlapping or conflicting pointers. This would
+ // indicate a pointer to a field, or a non-type safe value, neither of
+ // which are currently decodable.
+ //
+ // See the usage of values below for more context.
+ values addrSet
+
+ // w is the output stream.
+ w io.Writer
+
+ // pending is the list of objects to be serialized.
+ //
+ // This is a set of queuedObjects.
+ pending list.List
+
+ // done is a list of finished objects.
+ //
+ // This is kept to prevent garbage collection and address reuse.
+ done list.List
+
+ // stats is the passed stats object.
+ stats *Stats
+
+ // recoverable is the panic recover facility.
+ recoverable
+}
+
+// register looks up an ID, registering if necessary.
+//
+// If the object was not previously registered, it is enqueued to be serialized.
+// See the documentation for idsByObject for more information.
+//
+// obj must be a pointer or a map; anything else panics with an error.
+func (es *encodeState) register(obj reflect.Value) uint64 {
+ // It is not legal to call register for any non-pointer objects (see
+ // below), so we panic with a recoverable error if this is a mismatch.
+ if obj.Kind() != reflect.Ptr && obj.Kind() != reflect.Map {
+ panic(fmt.Errorf("non-pointer %#v registered", obj.Interface()))
+ }
+
+ addr := obj.Pointer()
+ if obj.Kind() == reflect.Ptr && obj.Elem().Type().Size() == 0 {
+ // For zero-sized objects, we always provide a unique ID.
+ // That's because the runtime internally multiplexes pointers
+ // to the same address. We can't be certain what the intent is
+ // with pointers to zero-sized objects, so we just give them
+ // all unique identities.
+ } else if id, ok := es.idsByObject[addr]; ok {
+ // Already registered.
+ return id
+ }
+
+ // Ensure that the first ID given out is one. See note on lastID. The
+ // ID zero is used to indicate nil values.
+ es.lastID++
+ id := es.lastID
+ es.idsByObject[addr] = id
+ if obj.Kind() == reflect.Ptr {
+ // Dereference and treat as a pointer.
+ es.pending.PushBack(queuedObject{id: id, obj: obj.Elem(), path: es.recoverable.copy()})
+
+ // Register this object at all addresses.
+ typ := obj.Elem().Type()
+ if size := typ.Size(); size > 0 {
+ r := addrRange{addr, addr + size}
+ if !es.values.IsEmptyRange(r) {
+ panic(fmt.Errorf("overlapping objects: [new object] %#v [existing object] %#v", obj.Interface(), es.values.FindSegment(addr).Value().Elem().Interface()))
+ }
+ es.values.Add(r, obj)
+ }
+ } else {
+ // Push back the map itself; when queued objects are encoded
+ // from the top-level in Serialize, mapAsValue is true, so the
+ // map is then encoded directly rather than as a reference.
+ es.pending.PushBack(queuedObject{id: id, obj: obj, path: es.recoverable.copy()})
+ }
+
+ return id
+}
+
+// encodeMap encodes a map.
+func (es *encodeState) encodeMap(obj reflect.Value) *pb.Map {
+ var (
+ keys []*pb.Object
+ values []*pb.Object
+ )
+ for i, k := range obj.MapKeys() {
+ v := obj.MapIndex(k)
+ kp := es.encodeObject(k, false, ".(key %d)", i)
+ vp := es.encodeObject(v, false, "[%#v]", k.Interface())
+ keys = append(keys, kp)
+ values = append(values, vp)
+ }
+ return &pb.Map{Keys: keys, Values: values}
+}
+
+// encodeStruct encodes a composite object.
+//
+// The save function registered for obj's pointer type fills a Map with the
+// fields, which are then sorted by name and emitted. Duplicate field names
+// panic with an error. An unregistered type is only accepted if the struct
+// has no fields.
+func (es *encodeState) encodeStruct(obj reflect.Value) *pb.Struct {
+ // Invoke the save.
+ m := Map{newInternalMap(es, nil, nil)}
+ defer internalMapPool.Put(m.internalMap)
+ if !obj.CanAddr() {
+ // Force it to a * type of the above; this involves a copy.
+ localObj := reflect.New(obj.Type())
+ localObj.Elem().Set(obj)
+ obj = localObj.Elem()
+ }
+ fns, ok := registeredTypes.lookupFns(obj.Addr().Type())
+ if ok {
+ // Invoke the provided saver.
+ fns.invokeSave(obj.Addr(), m)
+ } else if obj.NumField() == 0 {
+ // Allow unregistered anonymous, empty structs.
+ return &pb.Struct{}
+ } else {
+ // Propagate an error.
+ panic(fmt.Errorf("unregistered type %T", obj.Interface()))
+ }
+
+ // Sort the underlying slice, and check for duplicates. This is done
+ // once instead of on each add, because performing this sort once is
+ // far more efficient.
+ if len(m.data) > 1 {
+ sort.Slice(m.data, func(i, j int) bool {
+ return m.data[i].name < m.data[j].name
+ })
+ // After sorting, duplicates are adjacent.
+ for i := range m.data {
+ if i > 0 && m.data[i-1].name == m.data[i].name {
+ panic(fmt.Errorf("duplicate name %s", m.data[i].name))
+ }
+ }
+ }
+
+ // Encode the resulting fields.
+ fields := make([]*pb.Field, 0, len(m.data))
+ for _, e := range m.data {
+ fields = append(fields, &pb.Field{
+ Name: e.name,
+ Value: e.object,
+ })
+ }
+
+ // Return the encoded object.
+ return &pb.Struct{Fields: fields}
+}
+
+// encodeArray encodes an array.
+func (es *encodeState) encodeArray(obj reflect.Value) *pb.Array {
+ var (
+ contents []*pb.Object
+ )
+ for i := 0; i < obj.Len(); i++ {
+ entry := es.encodeObject(obj.Index(i), false, "[%d]", i)
+ contents = append(contents, entry)
+ }
+ return &pb.Array{Contents: contents}
+}
+
+// encodeInterface encodes an interface.
+//
+// A nil interface (no dynamic type) is encoded with an empty Type and a zero
+// reference. Otherwise the dynamic type must be registered, and its name is
+// stored alongside the encoded dynamic value; an unregistered type panics
+// with an error.
+func (es *encodeState) encodeInterface(obj reflect.Value) *pb.Interface {
+ // Check for the nil interface.
+ obj = reflect.ValueOf(obj.Interface())
+ if !obj.IsValid() {
+ return &pb.Interface{
+ Type: "", // left alone in decode.
+ Value: &pb.Object{Value: &pb.Object_RefValue{0}},
+ }
+ }
+ // We have an interface value here. How do we save that? We
+ // resolve the underlying type and save it as a dispatchable.
+ typName, ok := registeredTypes.lookupName(obj.Type())
+ if !ok {
+ panic(fmt.Errorf("type %s is not registered", obj.Type()))
+ }
+
+ // Encode the object again.
+ return &pb.Interface{
+ Type: typName,
+ Value: es.encodeObject(obj, false, ".(%s)", typName),
+ }
+}
+
+// encodeObject encodes an object.
+//
+// If mapAsValue is true, then a map will be encoded directly.
+func (es *encodeState) encodeObject(obj reflect.Value, mapAsValue bool, format string, param interface{}) (object *pb.Object) {
+ es.push(false, format, param)
+ es.stats.Start(obj)
+
+ switch obj.Kind() {
+ case reflect.Bool:
+ object = &pb.Object{Value: &pb.Object_BoolValue{obj.Bool()}}
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ object = &pb.Object{Value: &pb.Object_Int64Value{obj.Int()}}
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ object = &pb.Object{Value: &pb.Object_Uint64Value{obj.Uint()}}
+ case reflect.Float32, reflect.Float64:
+ object = &pb.Object{Value: &pb.Object_DoubleValue{obj.Float()}}
+ case reflect.Array:
+ switch obj.Type().Elem().Kind() {
+ case reflect.Uint8:
+ object = &pb.Object{Value: &pb.Object_ByteArrayValue{pbSlice(obj).Interface().([]byte)}}
+ case reflect.Uint16:
+ // 16-bit slices are serialized as 32-bit slices.
+ // See object.proto for details.
+ s := pbSlice(obj).Interface().([]uint16)
+ t := make([]uint32, len(s))
+ for i := range s {
+ t[i] = uint32(s[i])
+ }
+ object = &pb.Object{Value: &pb.Object_Uint16ArrayValue{&pb.Uint16S{Values: t}}}
+ case reflect.Uint32:
+ object = &pb.Object{Value: &pb.Object_Uint32ArrayValue{&pb.Uint32S{Values: pbSlice(obj).Interface().([]uint32)}}}
+ case reflect.Uint64:
+ object = &pb.Object{Value: &pb.Object_Uint64ArrayValue{&pb.Uint64S{Values: pbSlice(obj).Interface().([]uint64)}}}
+ case reflect.Uintptr:
+ object = &pb.Object{Value: &pb.Object_UintptrArrayValue{&pb.Uintptrs{Values: pbSlice(obj).Interface().([]uint64)}}}
+ case reflect.Int8:
+ object = &pb.Object{Value: &pb.Object_Int8ArrayValue{&pb.Int8S{Values: pbSlice(obj).Interface().([]byte)}}}
+ case reflect.Int16:
+ // 16-bit slices are serialized as 32-bit slices.
+ // See object.proto for details.
+ s := pbSlice(obj).Interface().([]int16)
+ t := make([]int32, len(s))
+ for i := range s {
+ t[i] = int32(s[i])
+ }
+ object = &pb.Object{Value: &pb.Object_Int16ArrayValue{&pb.Int16S{Values: t}}}
+ case reflect.Int32:
+ object = &pb.Object{Value: &pb.Object_Int32ArrayValue{&pb.Int32S{Values: pbSlice(obj).Interface().([]int32)}}}
+ case reflect.Int64:
+ object = &pb.Object{Value: &pb.Object_Int64ArrayValue{&pb.Int64S{Values: pbSlice(obj).Interface().([]int64)}}}
+ case reflect.Bool:
+ object = &pb.Object{Value: &pb.Object_BoolArrayValue{&pb.Bools{Values: pbSlice(obj).Interface().([]bool)}}}
+ case reflect.Float32:
+ object = &pb.Object{Value: &pb.Object_Float32ArrayValue{&pb.Float32S{Values: pbSlice(obj).Interface().([]float32)}}}
+ case reflect.Float64:
+ object = &pb.Object{Value: &pb.Object_Float64ArrayValue{&pb.Float64S{Values: pbSlice(obj).Interface().([]float64)}}}
+ default:
+ object = &pb.Object{Value: &pb.Object_ArrayValue{es.encodeArray(obj)}}
+ }
+ case reflect.Slice:
+ if obj.IsNil() || obj.Cap() == 0 {
+ // Handled specially in decode; store as nil value.
+ object = &pb.Object{Value: &pb.Object_RefValue{0}}
+ } else {
+ // Serialize a slice as the array plus length and capacity.
+ object = &pb.Object{Value: &pb.Object_SliceValue{&pb.Slice{
+ Capacity: uint32(obj.Cap()),
+ Length: uint32(obj.Len()),
+ RefValue: es.register(arrayFromSlice(obj)),
+ }}}
+ }
+ case reflect.String:
+ object = &pb.Object{Value: &pb.Object_StringValue{obj.String()}}
+ case reflect.Ptr:
+ if obj.IsNil() {
+ // Handled specially in decode; store as a nil value.
+ object = &pb.Object{Value: &pb.Object_RefValue{0}}
+ } else {
+ es.push(true /* dereference */, "", nil)
+ object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
+ es.pop()
+ }
+ case reflect.Interface:
+ // We don't check for IsNil here, as we want to encode type
+ // information. The case of the empty interface (no type, no
+ // value) is handled by encodeInterface.
+ object = &pb.Object{Value: &pb.Object_InterfaceValue{es.encodeInterface(obj)}}
+ case reflect.Struct:
+ object = &pb.Object{Value: &pb.Object_StructValue{es.encodeStruct(obj)}}
+ case reflect.Map:
+ if obj.IsNil() {
+ // Handled specially in decode; store as a nil value.
+ object = &pb.Object{Value: &pb.Object_RefValue{0}}
+ } else if mapAsValue {
+ // Encode the map directly.
+ object = &pb.Object{Value: &pb.Object_MapValue{es.encodeMap(obj)}}
+ } else {
+ // Encode a reference to the map.
+ object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
+ }
+ default:
+ panic(fmt.Errorf("unknown primitive %#v", obj.Interface()))
+ }
+
+ es.stats.Done()
+ es.pop()
+ return
+}
+
+// Serialize serializes the object state.
+//
+// The address of obj is registered first; the loop below then encodes and
+// writes every object found on the pending list until that list is empty,
+// and finally a zero-length, non-object header is written as a terminator.
+// obj must be addressable, because obj.Addr() is taken.
+//
+// This function may panic and should be run in safely().
+func (es *encodeState) Serialize(obj reflect.Value) {
+ // NOTE(review): presumably register queues the root on es.pending and
+ // assigns it an ID -- confirm against register.
+ es.register(obj.Addr())
+
+ // Pop off the list until we're done.
+ for es.pending.Len() > 0 {
+ e := es.pending.Front()
+ es.pending.Remove(e)
+
+ // Extract the queued object.
+ qo := e.Value.(queuedObject)
+ // Point the path context at the queued object's own path, so that a
+ // failure while encoding it is reported relative to that path.
+ es.from = &qo.path
+ o := es.encodeObject(qo.obj, true, "", nil)
+
+ // Emit to our output stream.
+ if err := es.writeObject(qo.id, o); err != nil {
+ panic(err)
+ }
+
+ // Mark as done.
+ // NOTE(review): this pushes the list element e itself rather than
+ // e.Value -- confirm that readers of es.done expect that.
+ es.done.PushBack(e)
+ }
+
+ // Write a zero-length terminal at the end; this is a sanity check
+ // applied at decode time as well (see decode.go).
+ if err := WriteHeader(es.w, 0, false); err != nil {
+ panic(err)
+ }
+}
+
+// WriteHeader writes a header.
+//
+// Each object written to the statefile should be prefixed with a header. In
+// order to generate statefiles that play nicely with debugging tools, raw
+// writes should be prefixed with a header with object set to false and the
+// appropriate length. This will allow tools to skip these regions.
+//
+// The header is the uvarint encoding of (length << 1), with the lowest bit
+// set when object is true. Because of the shift, the top bit of length is
+// discarded; callers must pass a length below 1<<63.
+//
+// A write error is returned only when the writer made no progress. If the
+// writer accepts no bytes and reports no error either, io.ErrShortWrite is
+// returned instead of retrying forever.
+func WriteHeader(w io.Writer, length uint64, object bool) error {
+	// The lowest-order bit encodes whether this is a valid object. This is
+	// a purely internal convention, but allows the object flag to be
+	// returned from ReadHeader.
+	length <<= 1
+	if object {
+		length |= 0x1
+	}
+
+	// Encode the header; a uint64 never needs more than MaxVarintLen64
+	// bytes as a uvarint.
+	var hdr [binary.MaxVarintLen64]byte
+	encodedLen := binary.PutUvarint(hdr[:], length)
+
+	// Keep writing until the whole header is out, tolerating partial writes.
+	for done := 0; done < encodedLen; {
+		n, err := w.Write(hdr[done:encodedLen])
+		done += n
+		if n == 0 {
+			if err == nil {
+				// No progress and no error: bail out rather than spin.
+				err = io.ErrShortWrite
+			}
+			return err
+		}
+	}
+
+	return nil
+}
+
+// writeObject writes an object to the stream.
+//
+// The object is marshalled, then written as an object header carrying the
+// marshalled length followed by the marshalled bytes. The id parameter is
+// not written by this function.
+//
+// A write error is returned only when the writer made no progress. If the
+// writer accepts no bytes and reports no error either, io.ErrShortWrite is
+// returned instead of retrying forever.
+func (es *encodeState) writeObject(id uint64, obj *pb.Object) error {
+	// Marshal the proto.
+	buf, err := proto.Marshal(obj)
+	if err != nil {
+		return err
+	}
+
+	// Write the object header.
+	if err := WriteHeader(es.w, uint64(len(buf)), true); err != nil {
+		return err
+	}
+
+	// Write the object, tolerating partial writes.
+	for done := 0; done < len(buf); {
+		n, err := es.w.Write(buf[done:])
+		done += n
+		if n == 0 {
+			if err == nil {
+				// No progress and no error: bail out rather than spin.
+				err = io.ErrShortWrite
+			}
+			return err
+		}
+	}
+
+	return nil
+}
+
+// addrSetFunctions is used by addrSet.
+//
+// It is the "Functions" type plugged into the generic set template that
+// generates addrSet (keys are uintptr, values are reflect.Value).
+type addrSetFunctions struct{}
+
+// MinKey returns the smallest possible key.
+func (addrSetFunctions) MinKey() uintptr {
+ return 0
+}
+
+// MaxKey returns the largest possible key (all bits set).
+func (addrSetFunctions) MaxKey() uintptr {
+ return ^uintptr(0)
+}
+
+// ClearValue is a no-op; val is left untouched.
+func (addrSetFunctions) ClearValue(val *reflect.Value) {
+}
+
+// Merge returns val1, and reports true only when val1 == val2.
+//
+// NOTE(review): == on reflect.Value compares the Value structs themselves,
+// not the values they hold (go vet's reflectvaluecompare check flags this)
+// -- confirm that this identity comparison is what is intended.
+func (addrSetFunctions) Merge(_ addrRange, val1 reflect.Value, _ addrRange, val2 reflect.Value) (reflect.Value, bool) {
+ return val1, val1 == val2
+}
+
+// Split returns val unchanged for both halves of the split.
+func (addrSetFunctions) Split(_ addrRange, val reflect.Value, _ uintptr) (reflect.Value, reflect.Value) {
+ return val, val
+}
diff --git a/pkg/state/encode_unsafe.go b/pkg/state/encode_unsafe.go
new file mode 100644
index 000000000..d96ba56d4
--- /dev/null
+++ b/pkg/state/encode_unsafe.go
@@ -0,0 +1,81 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// arrayFromSlice returns a pointer to an array that aliases the backing
+// store of the slice obj. The array length is the slice's capacity, not its
+// length.
+//
+// It is roughly equivalent to the following:
+//
+//	x := make([]Foo, l, c)
+//	a := (*[c]Foo)(unsafe.Pointer(&x[0]))
+//
+func arrayFromSlice(obj reflect.Value) reflect.Value {
+	backing := reflect.ArrayOf(obj.Cap(), obj.Type().Elem())
+	return reflect.NewAt(backing, unsafe.Pointer(obj.Pointer()))
+}
+
+// pbSlice returns a protobuf-supported slice over the elements of obj,
+// erasing the original element type (which could be a defined type or a
+// type protobuf does not support). The result aliases obj's memory; nothing
+// is copied.
+//
+// It panics if the element kind is not one of the basic kinds listed below.
+func pbSlice(obj reflect.Value) reflect.Value {
+	var elem reflect.Type
+	switch obj.Type().Elem().Kind() {
+	case reflect.Uint8, reflect.Int8:
+		// int8 data is exposed as raw bytes.
+		elem = reflect.TypeOf(byte(0))
+	case reflect.Uint16:
+		elem = reflect.TypeOf(uint16(0))
+	case reflect.Uint32:
+		elem = reflect.TypeOf(uint32(0))
+	case reflect.Uint64, reflect.Uintptr:
+		// NOTE(review): reinterpreting uintptr as uint64 assumes a 64-bit
+		// uintptr -- confirm for every supported platform.
+		elem = reflect.TypeOf(uint64(0))
+	case reflect.Int16:
+		elem = reflect.TypeOf(int16(0))
+	case reflect.Int32:
+		elem = reflect.TypeOf(int32(0))
+	case reflect.Int64:
+		elem = reflect.TypeOf(int64(0))
+	case reflect.Bool:
+		elem = reflect.TypeOf(bool(false))
+	case reflect.Float32:
+		elem = reflect.TypeOf(float32(0))
+	case reflect.Float64:
+		elem = reflect.TypeOf(float64(0))
+	default:
+		panic("slice element is not of basic value type")
+	}
+
+	// View the same memory as [n]elem, then slice that array.
+	n := obj.Len()
+	view := reflect.ArrayOf(n, elem)
+	data := unsafe.Pointer(obj.Slice(0, n).Pointer())
+	return reflect.NewAt(view, data).Elem().Slice(0, n)
+}
+
+// castSlice reinterprets the elements of obj as elemTyp. The result is an
+// array value (not a slice) of obj.Len() elements that aliases obj's memory.
+//
+// It panics if the two element types differ in size.
+func castSlice(obj reflect.Value, elemTyp reflect.Type) reflect.Value {
+	if from := obj.Type().Elem(); from.Size() != elemTyp.Size() {
+		panic("cannot cast slice into other element type of different size")
+	}
+	n := obj.Len()
+	view := reflect.ArrayOf(n, elemTyp)
+	data := unsafe.Pointer(obj.Slice(0, n).Pointer())
+	return reflect.NewAt(view, data).Elem()
+}
diff --git a/pkg/state/map.go b/pkg/state/map.go
new file mode 100644
index 000000000..c3d165501
--- /dev/null
+++ b/pkg/state/map.go
@@ -0,0 +1,221 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "fmt"
+ "reflect"
+ "sort"
+ "sync"
+
+ pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto"
+)
+
+// entry is a single map entry.
+type entry struct {
+ // name is the name the object was saved under.
+ name string
+ // object is the encoded object.
+ object *pb.Object
+}
+
+// internalMap is the internal Map state.
+//
+// These are recycled via a pool to avoid churn.
+type internalMap struct {
+ // es is the encodeState.
+ //
+ // Saving through the Map fails when this is nil.
+ es *encodeState
+
+ // ds is the decodeState.
+ //
+ // Loading through the Map fails when this is nil.
+ ds *decodeState
+
+ // os is the current object being decoded.
+ //
+ // This will always be nil during encode.
+ os *objectState
+
+ // data stores the encoded values.
+ //
+ // Save appends to it in call order; load binary-searches it by name and
+ // therefore relies on it being sorted by name at that point.
+ data []entry
+}
+
+// internalMapPool recycles internalMap objects; see newInternalMap.
+var internalMapPool = sync.Pool{
+ New: func() interface{} {
+ return new(internalMap)
+ },
+}
+
+// newInternalMap takes an internalMap from the pool and binds it to the
+// given encode state, decode state and object state. Entries left over from
+// a previous use are dropped, but the backing storage of data is kept so it
+// can be reused.
+func newInternalMap(es *encodeState, ds *decodeState, os *objectState) *internalMap {
+	im := internalMapPool.Get().(*internalMap)
+	im.es, im.ds, im.os = es, ds, os
+	// Truncating a nil slice yields nil, so no nil check is needed.
+	im.data = im.data[:0]
+	return im
+}
+
+// Map is a generic state container.
+//
+// This is the object passed to Save and Load in order to store their state.
+//
+// Detailed documentation is available in individual methods.
+type Map struct {
+ // The state is held behind a pointer, so copies of a Map share it.
+ *internalMap
+}
+
+// Save adds the given object to the map.
+//
+// You should always pass pointers to the object you are saving; the value
+// that objPtr points to is what gets encoded. For example:
+//
+// type X struct {
+// A int
+// B *int
+// }
+//
+// func (x *X) Save(m Map) {
+// m.Save("A", &x.A)
+// m.Save("B", &x.B)
+// }
+//
+// func (x *X) Load(m Map) {
+// m.Load("A", &x.A)
+// m.Load("B", &x.B)
+// }
+func (m Map) Save(name string, objPtr interface{}) {
+ m.save(name, reflect.ValueOf(objPtr).Elem(), ".%s")
+}
+
+// SaveValue adds the given object value to the map.
+//
+// This should be used for values where pointers are not available, or casts
+// are required during Save/Load.
+//
+// For example, if we want to cast external package type P.Foo to int64:
+//
+// type X struct {
+// A P.Foo
+// }
+//
+// func (x *X) Save(m Map) {
+// m.SaveValue("A", int64(x.A))
+// }
+//
+// func (x *X) Load(m Map) {
+// m.LoadValue("A", new(int64), func(y interface{}) {
+// x.A = P.Foo(y.(int64))
+// })
+// }
+func (m Map) SaveValue(name string, obj interface{}) {
+ m.save(name, reflect.ValueOf(obj), ".(value %s)")
+}
+
+// save is a helper for the above. It takes the name to save the field under,
+// the field object (obj), and a format string that specifies how the
+// field's saving logic is dispatched from the struct (normal, value, etc.). The
+// format string should expect one string parameter, which is the name of the
+// field.
+func (m Map) save(name string, obj reflect.Value, format string) {
+ if m.es == nil {
+ // Not currently encoding. Failf panics, so we do not continue.
+ m.Failf("no encode state for %q", name)
+ }
+
+ // Attempt the encode.
+ //
+ // Entries are appended in call order here; they are sorted and checked
+ // for duplicates at the end, after all objects are added (see
+ // encodeStruct).
+ m.data = append(m.data, entry{
+ name: name,
+ object: m.es.encodeObject(obj, false, format, name),
+ })
+}
+
+// Load loads the given object from the map.
+//
+// objPtr must be a pointer to the destination; the decoded value is stored
+// into what it points to. The load is not marked as something the current
+// object has to wait for (compare LoadWait).
+//
+// See Save for an example.
+func (m Map) Load(name string, objPtr interface{}) {
+ m.load(name, reflect.ValueOf(objPtr), false, nil, ".%s")
+}
+
+// LoadWait loads the given object from the map, and marks it as requiring all
+// AfterLoad executions to complete prior to running this object's AfterLoad.
+//
+// See Save for an example.
+func (m Map) LoadWait(name string, objPtr interface{}) {
+ m.load(name, reflect.ValueOf(objPtr), true, nil, ".(wait %s)")
+}
+
+// LoadValue loads the given object value from the map.
+//
+// The value is decoded into what objPtr points to, and fn is handed that
+// decoded value (as an interface{}) as the load's completion callback. Like
+// LoadWait, the load is marked as one the current object waits on.
+//
+// See SaveValue for an example.
+func (m Map) LoadValue(name string, objPtr interface{}, fn func(interface{})) {
+	target := reflect.ValueOf(objPtr)
+	deliver := func() {
+		fn(target.Elem().Interface())
+	}
+	m.load(name, target, true, deliver, ".(value %s)")
+}
+
+// load is the shared implementation behind Load, LoadWait and LoadValue.
+//
+// name is the name of the value to load the field from; objPtr is a pointer
+// to the target field; wait says whether load completion of the struct
+// depends on this field's load completion; fn is the load completion logic;
+// and format describes how the field's loading logic is dispatched from the
+// struct (normal, wait, value, etc.). The format string should expect one
+// string parameter, which is the name of the field.
+func (m Map) load(name string, objPtr reflect.Value, wait bool, fn func(), format string) {
+	if m.ds == nil {
+		// Not currently decoding.
+		m.Failf("no decode state for %q", name)
+	}
+
+	// Locate the entry by binary search. The entries are sorted up front
+	// (and should appear sorted in the state file as well), which keeps
+	// lookups cheap for large structs.
+	atOrAfter := func(j int) bool {
+		return m.data[j].name >= name
+	}
+	idx := sort.Search(len(m.data), atOrAfter)
+	if idx >= len(m.data) || m.data[idx].name != name {
+		// There is no data for this name?
+		m.Failf("no data found for %q", name)
+	}
+
+	// Perform the decode.
+	m.ds.decodeObject(m.os, objPtr.Elem(), m.data[idx].object, format, name)
+	if !wait {
+		return
+	}
+
+	// Mark this individual object a blocker.
+	m.ds.waitObject(m.os, m.data[idx].object, fn)
+}
+
+// Failf fails the save or restore with the provided message. Processing will
+// stop after calling Failf, as the state package uses a panic & recover
+// mechanism for state errors. You should defer any cleanup required.
+//
+// The panic value is an error built with fmt.Errorf from format and args.
+func (m Map) Failf(format string, args ...interface{}) {
+ panic(fmt.Errorf(format, args...))
+}
+
+// AfterLoad schedules a function to run once all objects have been allocated
+// and their automated loading and customized load logic have been executed. fn
+// will not be executed until the AfterLoad() logic, if any, of all of the
+// current object's dependencies has been executed.
+func (m Map) AfterLoad(fn func()) {
+ if m.ds == nil {
+ // Not currently decoding. Failf panics, so we do not continue.
+ m.Failf("not decoding")
+ }
+
+ // Queue the local callback; this will execute when all of the above
+ // data dependencies have been cleared.
+ m.os.callbacks = append(m.os.callbacks, fn)
+}
diff --git a/pkg/state/object.proto b/pkg/state/object.proto
new file mode 100644
index 000000000..6595c5519
--- /dev/null
+++ b/pkg/state/object.proto
@@ -0,0 +1,140 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package gvisor.state.statefile;
+
+// Slice is a slice value.
+//
+// The encoder stores the slice's length and capacity, and sets ref_value to
+// the ID registered for the slice's backing array.
+message Slice {
+ uint32 length = 1;
+ uint32 capacity = 2;
+ uint64 ref_value = 3;
+}
+
+// Array is an array value.
+message Array {
+ repeated Object contents = 1;
+}
+
+// Map is a map value.
+//
+// keys and values are parallel lists: values[i] belongs to keys[i].
+message Map {
+ repeated Object keys = 1;
+ repeated Object values = 2;
+}
+
+// Interface is an interface value.
+//
+// The printer treats an empty type as a nil interface.
+message Interface {
+ string type = 1;
+ Object value = 2;
+}
+
+// Struct is a basic composite value.
+message Struct {
+ repeated Field fields = 1;
+}
+
+// Field encodes a single field.
+message Field {
+ string name = 1;
+ Object value = 2;
+}
+
+// Uint16s encodes a uint16 array. To be used inside oneof structure.
+message Uint16s {
+ // There is no 16-bit type in protobuf so we use variable length 32-bit here.
+ repeated uint32 values = 1;
+}
+
+// Uint32s encodes a uint32 array. To be used inside oneof structure.
+message Uint32s {
+ repeated fixed32 values = 1;
+}
+
+// Uint64s encodes a uint64 array. To be used inside oneof structure.
+message Uint64s {
+ repeated fixed64 values = 1;
+}
+
+// Uintptrs encodes a uintptr array. To be used inside oneof structure.
+message Uintptrs {
+ repeated fixed64 values = 1;
+}
+
+// Int8s encodes an int8 array. To be used inside oneof structure.
+message Int8s {
+ // The int8 elements are carried as raw bytes.
+ bytes values = 1;
+}
+
+// Int16s encodes an int16 array. To be used inside oneof structure.
+message Int16s {
+ // There is no 16-bit type in protobuf so we use variable length 32-bit here.
+ repeated int32 values = 1;
+}
+
+// Int32s encodes an int32 array. To be used inside oneof structure.
+message Int32s {
+ repeated sfixed32 values = 1;
+}
+
+// Int64s encodes an int64 array. To be used inside oneof structure.
+message Int64s {
+ repeated sfixed64 values = 1;
+}
+
+// Bools encodes a boolean array. To be used inside oneof structure.
+message Bools {
+ repeated bool values = 1;
+}
+
+// Float64s encodes a float64 array. To be used inside oneof structure.
+message Float64s {
+ repeated double values = 1;
+}
+
+// Float32s encodes a float32 array. To be used inside oneof structure.
+message Float32s {
+ repeated float values = 1;
+}
+
+// Object is the primitive encoding of a single value.
+//
+// Note that ref_value refers to another object by its ID. The ID is not a
+// field of this message; the pretty-printer derives IDs from the order of
+// objects in the stream, starting at one. A ref_value of zero encodes nil.
+message Object {
+ oneof value {
+ bool bool_value = 1;
+ string string_value = 2;
+ int64 int64_value = 3;
+ uint64 uint64_value = 4;
+ double double_value = 5;
+ uint64 ref_value = 6;
+ Slice slice_value = 7;
+ Array array_value = 8;
+ Interface interface_value = 9;
+ Struct struct_value = 10;
+ Map map_value = 11;
+ // The typed array cases below carry arrays of basic element types.
+ bytes byte_array_value = 12;
+ Uint16s uint16_array_value = 13;
+ Uint32s uint32_array_value = 14;
+ Uint64s uint64_array_value = 15;
+ Uintptrs uintptr_array_value = 16;
+ Int8s int8_array_value = 17;
+ Int16s int16_array_value = 18;
+ Int32s int32_array_value = 19;
+ Int64s int64_array_value = 20;
+ Bools bool_array_value = 21;
+ Float64s float64_array_value = 22;
+ Float32s float32_array_value = 23;
+ }
+}
diff --git a/pkg/state/printer.go b/pkg/state/printer.go
new file mode 100644
index 000000000..c61ec4a26
--- /dev/null
+++ b/pkg/state/printer.go
@@ -0,0 +1,188 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "strings"
+
+ "github.com/golang/protobuf/proto"
+ pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto"
+)
+
+// format formats a single object, for pretty-printing.
+//
+// graph is the current graph number, used to build reference names of the
+// form g<graph>r<id>; depth is the nesting depth, used for tab indentation
+// of composite values; html selects whether references are emitted as HTML
+// anchors.
+//
+// The returned bool is false when the formatted value is a zero value
+// (false, "", 0, nil, or an empty or all-zero composite) and true otherwise.
+//
+// NOTE(review): the typed array cases of the oneof (byte_array_value,
+// uint16_array_value, ...) have no case here and fall through to the
+// "unknown proto type" result -- confirm whether that is intended.
+func format(graph uint64, depth int, object *pb.Object, html bool) (string, bool) {
+ switch x := object.GetValue().(type) {
+ case *pb.Object_BoolValue:
+ return fmt.Sprintf("%t", x.BoolValue), x.BoolValue != false
+ case *pb.Object_StringValue:
+ return fmt.Sprintf("\"%s\"", x.StringValue), x.StringValue != ""
+ case *pb.Object_Int64Value:
+ return fmt.Sprintf("%d", x.Int64Value), x.Int64Value != 0
+ case *pb.Object_Uint64Value:
+ return fmt.Sprintf("%du", x.Uint64Value), x.Uint64Value != 0
+ case *pb.Object_DoubleValue:
+ return fmt.Sprintf("%f", x.DoubleValue), x.DoubleValue != 0.0
+ case *pb.Object_RefValue:
+ // A zero reference is nil.
+ if x.RefValue == 0 {
+ return "nil", false
+ }
+ ref := fmt.Sprintf("g%dr%d", graph, x.RefValue)
+ if html {
+ ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref)
+ }
+ return ref, true
+ case *pb.Object_SliceValue:
+ if x.SliceValue.RefValue == 0 {
+ return "nil", false
+ }
+ ref := fmt.Sprintf("g%dr%d", graph, x.SliceValue.RefValue)
+ if html {
+ ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref)
+ }
+ return fmt.Sprintf("%s[:%d:%d]", ref, x.SliceValue.Length, x.SliceValue.Capacity), true
+ case *pb.Object_ArrayValue:
+ if len(x.ArrayValue.Contents) == 0 {
+ return "[]", false
+ }
+ items := make([]string, 0, len(x.ArrayValue.Contents)+2)
+ zeros := make([]string, 0) // used to eliminate zero entries.
+ items = append(items, "[")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ for i := 0; i < len(x.ArrayValue.Contents); i++ {
+ item, ok := format(graph, depth+1, x.ArrayValue.Contents[i], html)
+ if ok {
+ // Zero entries followed by a non-zero entry are printed in
+ // full; only a trailing run of zeros is summarized below.
+ if len(zeros) > 0 {
+ items = append(items, zeros...)
+ zeros = nil
+ }
+ items = append(items, fmt.Sprintf("\t%s,", item))
+ } else {
+ zeros = append(zeros, fmt.Sprintf("\t%s,", item))
+ }
+ }
+ if len(zeros) > 0 {
+ items = append(items, fmt.Sprintf("\t... (%d zero),", len(zeros)))
+ }
+ items = append(items, "]")
+ // zeros still covers every element only if no non-zero was seen.
+ return strings.Join(items, tabs), len(zeros) < len(x.ArrayValue.Contents)
+ case *pb.Object_StructValue:
+ if len(x.StructValue.Fields) == 0 {
+ return "struct{}", false
+ }
+ items := make([]string, 0, len(x.StructValue.Fields)+2)
+ items = append(items, "struct{")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ allZero := true
+ for _, field := range x.StructValue.Fields {
+ element, ok := format(graph, depth+1, field.Value, html)
+ allZero = allZero && !ok
+ items = append(items, fmt.Sprintf("\t%s: %s,", field.Name, element))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), !allZero
+ case *pb.Object_MapValue:
+ if len(x.MapValue.Keys) == 0 {
+ return "map{}", false
+ }
+ items := make([]string, 0, len(x.MapValue.Keys)+2)
+ items = append(items, "map{")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ // NOTE(review): this indexes Values by the Keys index and so assumes
+ // there are at least as many values as keys -- confirm.
+ for i := 0; i < len(x.MapValue.Keys); i++ {
+ key, _ := format(graph, depth+1, x.MapValue.Keys[i], html)
+ value, _ := format(graph, depth+1, x.MapValue.Values[i], html)
+ items = append(items, fmt.Sprintf("\t%s: %s,", key, value))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), true
+ case *pb.Object_InterfaceValue:
+ // An empty type name is printed as a nil interface.
+ if x.InterfaceValue.Type == "" {
+ return "interface(nil){}", false
+ }
+ element, _ := format(graph, depth+1, x.InterfaceValue.Value, html)
+ return fmt.Sprintf("interface(\"%s\"){%s}", x.InterfaceValue.Type, element), true
+ }
+
+ // Should not happen, but tolerate.
+ return fmt.Sprintf("(unknown proto type: %T)", object.GetValue()), true
+}
+
+// PrettyPrint reads the state stream from r, and pretty prints to w.
+//
+// Each object is printed on its own as "g<graph>r<id> = <value>". A
+// non-object header starts a new graph and resets the object ID; any data
+// that follows such a header is skipped and reported by size only. When html
+// is set, the output is wrapped in <pre> tags and object names become
+// anchors.
+//
+// Reading stops cleanly when ReadHeader reports io.EOF. Any other read,
+// unmarshal or write error is returned. A stream that ends in the middle of
+// a non-object region is reported as io.ErrUnexpectedEOF rather than being
+// treated as a clean end of input.
+func PrettyPrint(w io.Writer, r io.Reader, html bool) error {
+	var (
+		// current graph ID.
+		graph uint64
+
+		// current object ID.
+		id uint64
+	)
+
+	if html {
+		fmt.Fprintf(w, "<pre>")
+		defer fmt.Fprintf(w, "</pre>")
+	}
+
+	for {
+		// Find the first object to begin generation.
+		length, object, err := ReadHeader(r)
+		if err == io.EOF {
+			// Nothing else to do.
+			break
+		} else if err != nil {
+			return err
+		}
+		if !object {
+			// Increment the graph number & reset the ID.
+			graph++
+			id = 0
+			if length > 0 {
+				if _, err := fmt.Fprintf(w, "(%d bytes non-object data)\n", length); err != nil {
+					return err
+				}
+				// Skip the raw region. CopyN reports io.EOF when the
+				// stream ends early; surface that as a truncation.
+				if _, err := io.CopyN(ioutil.Discard, r, int64(length)); err != nil {
+					if err == io.EOF {
+						err = io.ErrUnexpectedEOF
+					}
+					return err
+				}
+			}
+			continue
+		}
+
+		// Read & unmarshal the object.
+		buf := make([]byte, length)
+		for done := 0; done < len(buf); {
+			n, err := r.Read(buf[done:])
+			done += n
+			if n == 0 && err != nil {
+				return err
+			}
+		}
+		obj := new(pb.Object)
+		if err := proto.Unmarshal(buf, obj); err != nil {
+			return err
+		}
+
+		id++ // First object must be one.
+		str, _ := format(graph, 0, obj, html)
+		tag := fmt.Sprintf("g%dr%d", graph, id)
+		if html {
+			tag = fmt.Sprintf("<a name=%s>%s</a>", tag, tag)
+		}
+		if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
diff --git a/pkg/state/state.go b/pkg/state/state.go
new file mode 100644
index 000000000..23a0b5922
--- /dev/null
+++ b/pkg/state/state.go
@@ -0,0 +1,349 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package state provides functionality related to saving and loading object
+// graphs. For most types, it provides a set of default saving / loading logic
+// that will be invoked automatically if custom logic is not defined.
+//
+// Kind Support
+// ---- -------
+// Bool default
+// Int default
+// Int8 default
+// Int16 default
+// Int32 default
+// Int64 default
+// Uint default
+// Uint8 default
+// Uint16 default
+// Uint32 default
+// Uint64 default
+// Float32 default
+// Float64 default
+// Complex64 custom
+// Complex128 custom
+// Array default
+// Chan custom
+// Func custom
+// Interface custom
+// Map default (*)
+// Ptr default
+// Slice default
+// String default
+// Struct custom
+// UnsafePointer custom
+//
+// (*) Maps are treated as value types by this package, even if they are
+// pointers internally. If you want to save two independent references
+// to the same map value, you must explicitly use a pointer to a map.
+package state
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+ "runtime"
+
+ pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto"
+)
+
+// ErrState is the error type produced when encoding or decoding fails.
+type ErrState struct {
+	// Err is the underlying error.
+	Err error
+
+	// path is the visit path from root to the current object.
+	path string
+
+	// trace is the stack trace.
+	trace string
+}
+
+// Error renders the underlying error, then the state path, then the stack
+// trace, each starting on its own line.
+func (e *ErrState) Error() string {
+	return fmt.Sprint(e.Err) + ":\nstate path: " + e.path + "\n" + e.trace
+}
+
+// Save saves the given object state.
+//
+// rootPtr must be a pointer to the root object; what it points to is
+// serialized to w. Any panic raised during encoding is caught and returned
+// as an error.
+func Save(w io.Writer, rootPtr interface{}, stats *Stats) error {
+	es := &encodeState{
+		w:           w,
+		stats:       stats,
+		idsByObject: make(map[uintptr]uint64),
+	}
+
+	// The dereference happens inside the closure so that a bad rootPtr is
+	// reported as an error rather than escaping as a panic.
+	encode := func() {
+		root := reflect.ValueOf(rootPtr).Elem()
+		es.Serialize(root)
+	}
+	return es.safely(encode)
+}
+
+// Load loads a checkpoint.
+//
+// rootPtr must be a pointer to the root object; the state read from r is
+// deserialized into what it points to. Any panic raised during decoding is
+// caught and returned as an error.
+func Load(r io.Reader, rootPtr interface{}, stats *Stats) error {
+	ds := &decodeState{
+		r:           r,
+		stats:       stats,
+		objectsByID: make(map[uint64]*objectState),
+		deferred:    make(map[uint64]*pb.Object),
+	}
+
+	// The dereference happens inside the closure so that a bad rootPtr is
+	// reported as an error rather than escaping as a panic.
+	decode := func() {
+		root := reflect.ValueOf(rootPtr).Elem()
+		ds.Deserialize(root)
+	}
+	return ds.safely(decode)
+}
+
+// Fns are the state dispatch functions.
+//
+// Both fields hold function values; see validateStateFn for the exact
+// signature that is accepted.
+type Fns struct {
+ // Save is a function like Save(concreteType, Map).
+ Save interface{}
+
+ // Load is a function like Load(concreteType, Map).
+ Load interface{}
+}
+
+// invokeSave executes the save function via reflection, passing obj and m.
+func (fns *Fns) invokeSave(obj reflect.Value, m Map) {
+ reflect.ValueOf(fns.Save).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+}
+
+// invokeLoad executes the load function via reflection, passing obj and m.
+func (fns *Fns) invokeLoad(obj reflect.Value, m Map) {
+ reflect.ValueOf(fns.Load).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+}
+
+// validateStateFn ensures types are correct.
+//
+// It reports whether fn is a function that takes exactly two parameters, a
+// typ followed by a Map, and returns nothing. A nil fn is reported as
+// invalid.
+func validateStateFn(fn interface{}, typ reflect.Type) bool {
+	fnTyp := reflect.TypeOf(fn)
+	if fnTyp == nil {
+		// reflect.TypeOf returns nil for a nil interface value; calling
+		// Kind on that would panic.
+		return false
+	}
+	if fnTyp.Kind() != reflect.Func {
+		return false
+	}
+	if fnTyp.NumIn() != 2 {
+		return false
+	}
+	if fnTyp.NumOut() != 0 {
+		return false
+	}
+	if fnTyp.In(0) != typ {
+		return false
+	}
+	if fnTyp.In(1) != reflect.TypeOf(Map{}) {
+		return false
+	}
+	return true
+}
+
+// Validate validates all state functions.
+//
+// It reports whether both Save and Load have the signature expected for typ
+// (see validateStateFn).
+func (fns *Fns) Validate(typ reflect.Type) bool {
+ return validateStateFn(fns.Save, typ) && validateStateFn(fns.Load, typ)
+}
+
+// typeDatabase maps registered names to types and back, and types to their
+// state functions. The three tables are filled together by register.
+type typeDatabase struct {
+ // nameToType is a forward lookup table.
+ nameToType map[string]reflect.Type
+
+ // typeToName is the reverse lookup table.
+ typeToName map[reflect.Type]string
+
+ // typeToFns is the function lookup table.
+ typeToFns map[reflect.Type]Fns
+}
+
+// registeredTypes is a database used for SaveInterface and LoadInterface.
+//
+// It is populated through Register.
+var registeredTypes = typeDatabase{
+ nameToType: make(map[string]reflect.Type),
+ typeToName: make(map[reflect.Type]string),
+ typeToFns: make(map[reflect.Type]Fns),
+}
+
+// register registers a type under the given name. This will generally be
+// called via init() methods, and therefore uses panic to propagate errors.
+//
+// It panics if name is already in use by another type, or if typ has already
+// been registered under any name. The messages use typ.String() rather than
+// typ.Name(), because Name is empty for non-defined types such as pointers,
+// which is what Register is documented to be called with.
+func (t *typeDatabase) register(name string, typ reflect.Type, fns Fns) {
+	// We can't allow name collisions.
+	if ot, ok := t.nameToType[name]; ok {
+		panic(fmt.Sprintf("type %q can't use name %q, already in use by type %q", typ.String(), name, ot.String()))
+	}
+
+	// Or multiple registrations.
+	if on, ok := t.typeToName[typ]; ok {
+		panic(fmt.Sprintf("type %q can't be registered as %q, already registered as %q", typ.String(), name, on))
+	}
+
+	t.nameToType[name] = typ
+	t.typeToName[typ] = name
+	t.typeToFns[typ] = fns
+}
+
+// lookupType returns the type registered under name; ok is false when no
+// such name has been registered.
+func (t *typeDatabase) lookupType(name string) (typ reflect.Type, ok bool) {
+	typ, ok = t.nameToType[name]
+	return typ, ok
+}
+
+// lookupName returns the name that typ was registered under; ok is false
+// when typ has not been registered.
+func (t *typeDatabase) lookupName(typ reflect.Type) (name string, ok bool) {
+	name, ok = t.typeToName[typ]
+	return name, ok
+}
+
+// lookupFns returns the state functions registered for typ; ok is false when
+// typ has not been registered.
+func (t *typeDatabase) lookupFns(typ reflect.Type) (fns Fns, ok bool) {
+	fns, ok = t.typeToFns[typ]
+	return fns, ok
+}
+
+// Register must be called for every type that implements Loader and is used
+// as an interface implementation.
+//
+// Register should be called either immediately after startup or via init()
+// methods. Double registration of either names or types will result in a panic.
+//
+// No synchronization is provided; this should only be called in init.
+//
+// The type that gets registered is the dynamic type of instance.
+//
+// Example usage:
+//
+// state.Register("Foo", (*Foo)(nil), state.Fns{
+// Save: (*Foo).Save,
+// Load: (*Foo).Load,
+// })
+//
+func Register(name string, instance interface{}, fns Fns) {
+ registeredTypes.register(name, reflect.TypeOf(instance), fns)
+}
+
+// IsZeroValue reports whether val is the zero value of its dynamic type. An
+// untyped nil is treated as zero.
+//
+// This function is used by the stateify tool.
+func IsZeroValue(val interface{}) bool {
+	typ := reflect.TypeOf(val)
+	if typ == nil {
+		// Only the nil interface has no dynamic type.
+		return true
+	}
+	zero := reflect.Zero(typ).Interface()
+	return reflect.DeepEqual(val, zero)
+}
+
+// step captures one encoding / decoding step. On each step, there is up to one
+// choice made, which is captured by non-nil param. We intentionally do not
+// eagerly create the final path string, as that will only be needed upon panic.
+type step struct {
+ // dereference indicates whether the current object is obtained by
+ // dereferencing a pointer.
+ dereference bool
+
+ // format is the formatting string that takes param below, if
+ // non-nil. For example, in array indexing case, we have "[%d]".
+ format string
+
+ // param stores the choice made at the current encoding / decoding step.
+ // For example, in array indexing case, param stores the index. When no
+ // choice is made, e.g. dereference, param should be nil.
+ param interface{}
+}
+
+// recoverable is the state encoding / decoding panic recovery facility. It is
+// also used to store encoding / decoding steps as well as the reference to the
+// original queued object from which the current object is dispatched. The
+// complete encoding / decoding path is synthesised from the steps in all queued
+// objects leading to the current object.
+type recoverable struct {
+ // from is the context this one was dispatched from; a nil from is
+ // reported by path as "root".
+ from *recoverable
+ // steps are the steps taken so far within this context.
+ steps []step
+}
+
+// push enters a new context level by recording one more step.
+func (sr *recoverable) push(dereference bool, format string, param interface{}) {
+	entered := step{
+		dereference: dereference,
+		format:      format,
+		param:       param,
+	}
+	sr.steps = append(sr.steps, entered)
+}
+
+// pop exits the current context level by dropping the most recent step.
+//
+// The last remaining step is never removed: with one step or none recorded,
+// pop does nothing.
+func (sr *recoverable) pop() {
+	if depth := len(sr.steps); depth > 1 {
+		sr.steps = sr.steps[:depth-1]
+	}
+}
+
+// path returns the complete encoding / decoding path from root. This is only
+// called upon panic.
+//
+// The path of the parent context comes first ("root" when there is no
+// parent), followed by each recorded step: a dereference wraps everything so
+// far in "*(...)", and the step's format is appended, expanded with its
+// param when there is one.
+func (sr *recoverable) path() string {
+	if sr.from == nil {
+		return "root"
+	}
+
+	out := sr.from.path()
+	for i := range sr.steps {
+		st := &sr.steps[i]
+		if st.dereference {
+			out = "*(" + out + ")"
+		}
+		suffix := st.format
+		if st.param != nil {
+			suffix = fmt.Sprintf(st.format, st.param)
+		}
+		out += suffix
+	}
+	return out
+}
+
+// copy returns a recoverable with the same parent and its own copy of the
+// steps, so that later pushes and pops on sr do not affect the result.
+func (sr *recoverable) copy() recoverable {
+	dup := recoverable{from: sr.from}
+	dup.steps = append(dup.steps, sr.steps...)
+	return dup
+}
+
+// safely executes the given function, catching a panic and unpacking as an error.
+//
+// The error flow through the state package uses panic and recover. There are
+// two important reasons for this:
+//
+// 1) Many of the reflection methods will already panic with invalid data or
+// violated assumptions. We would want to recover anyways here.
+//
+// 2) It allows us to eliminate boilerplate within Save() and Load() functions.
+// In nearly all cases, when the low-level serialization functions fail, you
+// will want the checkpoint to fail anyways. Plumbing errors through every
+// method doesn't add a lot of value. If there are specific error conditions
+// that you'd like to handle, you should add appropriate functionality to
+// objects themselves prior to calling Save() and Load().
+//
+// On panic, the returned error is an *ErrState carrying the panic value (as
+// an error), the current state path and a stack trace. If fn returns
+// normally, nil is returned.
+func (sr *recoverable) safely(fn func()) (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ es := new(ErrState)
+ // Keep the panic value as-is if it is already an error;
+ // otherwise wrap its default formatting in one.
+ if e, ok := r.(error); ok {
+ es.Err = e
+ } else {
+ es.Err = fmt.Errorf("%v", r)
+ }
+
+ es.path = sr.path()
+
+ // Make a stack. We don't know how big it will be ahead
+ // of time, but want to make sure we get the whole
+ // thing. So we just do a stupid brute force approach.
+ var stack []byte
+ for sz := 1024; ; sz *= 2 {
+ stack = make([]byte, sz)
+ n := runtime.Stack(stack, false)
+ // A full buffer may mean the trace was cut short, so
+ // retry with double the size.
+ if n < sz {
+ es.trace = string(stack[:n])
+ break
+ }
+ }
+
+ // Set the error.
+ err = es
+ }
+ }()
+
+ // Execute the function.
+ fn()
+ return nil
+}
diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go
new file mode 100644
index 000000000..d5a739f18
--- /dev/null
+++ b/pkg/state/state_test.go
@@ -0,0 +1,719 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "bytes"
+ "io/ioutil"
+ "math"
+ "reflect"
+ "testing"
+)
+
+// TestCase is used to define a single success/failure testcase of
+// serialization of a set of objects.
+//
+// runTest saves and reloads each entry of Objects independently and checks
+// each outcome against Fail.
+type TestCase struct {
+ // Name is the name of the test case.
+ Name string
+
+ // Objects is the list of values to serialize.
+ Objects []interface{}
+
+ // Fail is whether the test case is supposed to fail or not. When set,
+ // every object must fail at some stage (save, load or comparison).
+ Fail bool
+}
+
+// runTest runs all testcases.
+func runTest(t *testing.T, tests []TestCase) {
+ for _, test := range tests {
+ t.Logf("TEST %s:", test.Name)
+ for i, root := range test.Objects {
+ t.Logf(" case#%d: %#v", i, root)
+
+ // Save the passed object.
+ saveBuffer := &bytes.Buffer{}
+ saveObjectPtr := reflect.New(reflect.TypeOf(root))
+ saveObjectPtr.Elem().Set(reflect.ValueOf(root))
+ if err := Save(saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail {
+ t.Errorf(" FAIL: Save failed unexpectedly: %v", err)
+ continue
+ } else if err != nil {
+ t.Logf(" PASS: Save failed as expected: %v", err)
+ continue
+ }
+
+ // Load a new copy of the object.
+ loadObjectPtr := reflect.New(reflect.TypeOf(root))
+ if err := Load(bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail {
+ t.Errorf(" FAIL: Load failed unexpectedly: %v", err)
+ continue
+ } else if err != nil {
+ t.Logf(" PASS: Load failed as expected: %v", err)
+ continue
+ }
+
+ // Compare the values.
+ loadedValue := loadObjectPtr.Elem().Interface()
+ if eq := reflect.DeepEqual(root, loadedValue); !eq && !test.Fail {
+ t.Errorf(" FAIL: Objects differs; got %#v", loadedValue)
+ continue
+ } else if !eq {
+ t.Logf(" PASS: Object different as expected.")
+ continue
+ }
+
+ // Everything went okay. Is that good?
+ if test.Fail {
+ t.Errorf(" FAIL: Unexpected success.")
+ } else {
+ t.Logf(" PASS: Success.")
+ }
+ }
+ }
+}
+
+// dumbStruct is a struct which does not implement the loader/saver interface.
+// We expect that serialization of this struct will fail.
+//
+// Unlike the other fixtures in this file it has no Register call in init; it
+// is used by the "unenlightened structs" case in TestTypes.
+type dumbStruct struct {
+ A int
+ B int
+}
+
+// smartStruct is a struct which does implement the loader/saver interface.
+// We expect that serialization of this struct will succeed.
+type smartStruct struct {
+ A int
+ B int
+}
+
+// save hands pointers to both fields to m under the keys "A" and "B".
+func (s *smartStruct) save(m Map) {
+ m.Save("A", &s.A)
+ m.Save("B", &s.B)
+}
+
+// load asks m to fill both fields from the keys "A" and "B".
+func (s *smartStruct) load(m Map) {
+ m.Load("A", &s.A)
+ m.Load("B", &s.B)
+}
+
+// valueLoadStruct uses a value load.
+type valueLoadStruct struct {
+ v int
+}
+
+// save stores the current value of v (not a pointer to it) under key "v".
+func (v *valueLoadStruct) save(m Map) {
+ m.SaveValue("v", v.v)
+}
+
+// load requests key "v" as an int and copies it into v.v from the callback.
+func (v *valueLoadStruct) load(m Map) {
+ m.LoadValue("v", new(int), func(value interface{}) {
+ // The type assertion panics if the decoded value is not an int.
+ v.v = value.(int)
+ })
+}
+
+// afterLoadStruct has an AfterLoad function.
+//
+// v is never saved; the AfterLoad hook increments it instead. A freshly
+// loaded object therefore ends up with v == 1, which is why TestTypes only
+// uses instances with v: 1.
+type afterLoadStruct struct {
+ v int
+}
+
+// save intentionally records nothing.
+func (a *afterLoadStruct) save(m Map) {
+}
+
+// load registers a hook with m.AfterLoad that increments v.
+func (a *afterLoadStruct) load(m Map) {
+ m.AfterLoad(func() {
+ a.v++
+ })
+}
+
+// genericContainer is a generic dispatcher.
+//
+// It holds an arbitrary value behind an empty interface, so the tests can
+// exercise serialization of interface-typed fields.
+type genericContainer struct {
+ v interface{}
+}
+
+// save hands a pointer to the interface field to m under key "v".
+func (g *genericContainer) save(m Map) {
+ m.Save("v", &g.v)
+}
+
+// load asks m to fill the interface field from key "v".
+func (g *genericContainer) load(m Map) {
+ m.Load("v", &g.v)
+}
+
+// sliceContainer is a generic slice.
+//
+// Its elements are empty interfaces, so each element may hold a different
+// concrete type.
+type sliceContainer struct {
+ v []interface{}
+}
+
+// save hands a pointer to the slice to m under key "v".
+func (s *sliceContainer) save(m Map) {
+ m.Save("v", &s.v)
+}
+
+// load asks m to fill the slice from key "v".
+func (s *sliceContainer) load(m Map) {
+ m.Load("v", &s.v)
+}
+
+// mapContainer is a generic map.
+//
+// Keys are ints and values are empty interfaces.
+type mapContainer struct {
+ v map[int]interface{}
+}
+
+// save hands a pointer to the map to m under key "v".
+func (mc *mapContainer) save(m Map) {
+ m.Save("v", &mc.v)
+}
+
+// load fills the map from key "v" using LoadWait rather than Load.
+func (mc *mapContainer) load(m Map) {
+ // Some of the test cases below assume legacy behavior wherein maps
+ // will automatically inherit dependencies.
+ m.LoadWait("v", &mc.v)
+}
+
+// dumbMap is a map which does not implement the loader/saver interface.
+// Serialization of this map will default to the standard encode/decode logic.
+//
+// It appears in the "maps" case of TestTypes, both populated and as a nil map.
+type dumbMap map[string]int
+
+// pointerStruct contains various pointers, shared and non-shared, and pointers
+// to pointers. We expect that serialization will respect the structure.
+type pointerStruct struct {
+ // A through D are plain pointers; TestTypes makes some of them alias.
+ A *int
+ B *int
+ C *int
+ D *int
+
+ // AA and BB are pointers to pointers.
+ AA **int
+ BB **int
+}
+
+// save hands a pointer to every field to m, keyed by the field name.
+func (p *pointerStruct) save(m Map) {
+ m.Save("A", &p.A)
+ m.Save("B", &p.B)
+ m.Save("C", &p.C)
+ m.Save("D", &p.D)
+ m.Save("AA", &p.AA)
+ m.Save("BB", &p.BB)
+}
+
+// load asks m to fill every field from the key matching the field name.
+func (p *pointerStruct) load(m Map) {
+ m.Load("A", &p.A)
+ m.Load("B", &p.B)
+ m.Load("C", &p.C)
+ m.Load("D", &p.D)
+ m.Load("AA", &p.AA)
+ m.Load("BB", &p.BB)
+}
+
+// testInterface is a trivial interface example.
+//
+// It is the static type of testI.I and is implemented by *testImpl.
+type testInterface interface {
+ Foo()
+}
+
+// testImpl is a trivial implementation of testInterface.
+//
+// It has no fields, so there is nothing to save or load.
+type testImpl struct {
+}
+
+// Foo satisfies testInterface.
+func (t *testImpl) Foo() {
+}
+
+// testImpl is trivially serializable: save records nothing.
+func (t *testImpl) save(m Map) {
+}
+
+// testImpl is trivially serializable: load restores nothing.
+func (t *testImpl) load(m Map) {
+}
+
+// testI demonstrates interface dispatching.
+//
+// TestTypes stores a *testImpl, an untyped nil and a typed nil *testImpl in I.
+type testI struct {
+ I testInterface
+}
+
+// save hands a pointer to the interface field to m under key "I".
+func (t *testI) save(m Map) {
+ m.Save("I", &t.I)
+}
+
+// load asks m to fill the interface field from key "I".
+func (t *testI) load(m Map) {
+ m.Load("I", &t.I)
+}
+
+// cycleStruct is used to implement basic cycles.
+//
+// c may point back at the struct itself or at another cycleStruct that
+// points back, as set up in TestTypes.
+type cycleStruct struct {
+ c *cycleStruct
+}
+
+// save hands a pointer to the c field to m under key "c".
+func (c *cycleStruct) save(m Map) {
+ m.Save("c", &c.c)
+}
+
+// load fills c with a plain Load, in contrast to badCycleStruct which uses
+// LoadWait.
+func (c *cycleStruct) load(m Map) {
+ m.Load("c", &c.c)
+}
+
+// badCycleStruct actually has deadlocking dependencies.
+//
+// This should pass if b.b = {nil|b} and fail otherwise.
+type badCycleStruct struct {
+ b *badCycleStruct
+}
+
+// save hands a pointer to the b field to m under key "b".
+func (b *badCycleStruct) save(m Map) {
+ m.Save("b", &b.b)
+}
+
+// load fills b with LoadWait and additionally registers an (empty) AfterLoad
+// hook; the combination is what makes a two-object cycle unloadable.
+func (b *badCycleStruct) load(m Map) {
+ m.LoadWait("b", &b.b)
+ m.AfterLoad(func() {
+ // This is not executable, since AfterLoad requires that the
+ // object and all dependencies are complete. This should cause
+ // a deadlock error during load.
+ })
+}
+
+// emptyStructPointer points to an empty struct.
+//
+// TestTypes uses it to check that several holders of the same *struct{}
+// survive a round trip.
+type emptyStructPointer struct {
+ nothing *struct{}
+}
+
+// save hands a pointer to the nothing field to m under key "nothing".
+func (e *emptyStructPointer) save(m Map) {
+ m.Save("nothing", &e.nothing)
+}
+
+// load asks m to fill the nothing field from key "nothing".
+func (e *emptyStructPointer) load(m Map) {
+ m.Load("nothing", &e.nothing)
+}
+
+// truncateInteger truncates an integer.
+//
+// The value is saved as an int64 but loaded into an int32, so a round trip
+// only reproduces v when it fits in 32 bits.
+type truncateInteger struct {
+ v int64
+ v2 int32
+}
+
+// save mirrors v into v2 (truncating) and saves the 64-bit v under key "v".
+func (t *truncateInteger) save(m Map) {
+ t.v2 = int32(t.v)
+ m.Save("v", &t.v)
+}
+
+// load reads key "v" into the 32-bit v2 and then widens v2 back into v.
+// NOTE(review): the widening runs immediately after m.Load returns; this
+// presumably relies on Load filling v2 synchronously — confirm against Map.
+func (t *truncateInteger) load(m Map) {
+ m.Load("v", &t.v2)
+ t.v = int64(t.v2)
+}
+
+// truncateUnsignedInteger truncates an unsigned integer.
+//
+// The value is saved as a uint64 but loaded into a uint32, so a round trip
+// only reproduces v when it fits in 32 bits.
+type truncateUnsignedInteger struct {
+ v uint64
+ v2 uint32
+}
+
+// save mirrors v into v2 (truncating) and saves the 64-bit v under key "v".
+func (t *truncateUnsignedInteger) save(m Map) {
+ t.v2 = uint32(t.v)
+ m.Save("v", &t.v)
+}
+
+// load reads key "v" into the 32-bit v2 and then widens v2 back into v.
+// NOTE(review): the widening runs immediately after m.Load returns; this
+// presumably relies on Load filling v2 synchronously — confirm against Map.
+func (t *truncateUnsignedInteger) load(m Map) {
+ m.Load("v", &t.v2)
+ t.v = uint64(t.v2)
+}
+
+// truncateFloat truncates a floating point number.
+//
+// The value is saved as a float64 but loaded into a float32, so a round trip
+// only reproduces v when it is representable as a float32.
+type truncateFloat struct {
+ v float64
+ v2 float32
+}
+
+// save mirrors v into v2 (narrowing) and saves the 64-bit v under key "v".
+func (t *truncateFloat) save(m Map) {
+ t.v2 = float32(t.v)
+ m.Save("v", &t.v)
+}
+
+// load reads key "v" into the 32-bit v2 and then widens v2 back into v.
+// NOTE(review): the widening runs immediately after m.Load returns; this
+// presumably relies on Load filling v2 synchronously — confirm against Map.
+func (t *truncateFloat) load(m Map) {
+ m.Load("v", &t.v2)
+ t.v = float64(t.v2)
+}
+
+// TestTypes round-trips a catalogue of values through Save and Load via
+// runTest. Cases with Fail set list objects that must fail at some stage of
+// the round trip; every other case must reproduce each object exactly.
+func TestTypes(t *testing.T) {
+	// x and y are basic integers, while xp points to x.
+	x := 1
+	y := 2
+	xp := &x
+
+	// cs is a single object cycle.
+	cs := cycleStruct{nil}
+	cs.c = &cs
+
+	// cs1 and cs2 are in a two object cycle.
+	cs1 := cycleStruct{nil}
+	cs2 := cycleStruct{nil}
+	cs1.c = &cs2
+	cs2.c = &cs1
+
+	// bs is a single object cycle.
+	bs := badCycleStruct{nil}
+	bs.b = &bs
+
+	// bs1 and bs2 are in a deadlocking cycle.
+	bs1 := badCycleStruct{nil}
+	bs2 := badCycleStruct{nil}
+	bs1.b = &bs2
+	bs2.b = &bs1
+
+	// regular nils.
+	var (
+		nilmap   dumbMap
+		nilslice []byte
+	)
+
+	// embed points to embedded fields.
+	embed1 := pointerStruct{}
+	embed1.AA = &embed1.A
+	embed2 := pointerStruct{}
+	embed2.BB = &embed2.B
+
+	// es1 contains two structs pointing to the same empty struct.
+	es := emptyStructPointer{new(struct{})}
+	es1 := []emptyStructPointer{es, es}
+
+	tests := []TestCase{
+		{
+			Name: "bool",
+			Objects: []interface{}{
+				true,
+				false,
+			},
+		},
+		{
+			Name: "integers",
+			Objects: []interface{}{
+				int(0),
+				int(1),
+				int(-1),
+				int8(0),
+				int8(1),
+				int8(-1),
+				int16(0),
+				int16(1),
+				int16(-1),
+				int32(0),
+				int32(1),
+				int32(-1),
+				int64(0),
+				int64(1),
+				int64(-1),
+			},
+		},
+		{
+			// Every width covers both zero and a non-zero value.
+			Name: "unsigned integers",
+			Objects: []interface{}{
+				uint(0),
+				uint(1),
+				uint8(0),
+				uint8(1),
+				uint16(0),
+				uint16(1),
+				uint32(0),
+				uint32(1),
+				uint64(0),
+				uint64(1),
+			},
+		},
+		{
+			Name: "strings",
+			Objects: []interface{}{
+				"",
+				"foo",
+				"bar",
+			},
+		},
+		{
+			Name: "slices",
+			Objects: []interface{}{
+				[]int{-1, 0, 1},
+				[]*int{&x, &x, &x},
+				[]int{1, 2, 3}[0:1],
+				[]int{1, 2, 3}[1:2],
+				make([]byte, 32),
+				make([]byte, 32)[:16],
+				make([]byte, 32)[:16:20],
+				nilslice,
+			},
+		},
+		{
+			Name: "arrays",
+			Objects: []interface{}{
+				&[1048576]bool{false, true, false, true},
+				&[1048576]uint8{0, 1, 2, 3},
+				&[1048576]byte{0, 1, 2, 3},
+				&[1048576]uint16{0, 1, 2, 3},
+				&[1048576]uint{0, 1, 2, 3},
+				&[1048576]uint32{0, 1, 2, 3},
+				&[1048576]uint64{0, 1, 2, 3},
+				&[1048576]uintptr{0, 1, 2, 3},
+				&[1048576]int8{0, -1, -2, -3},
+				&[1048576]int16{0, -1, -2, -3},
+				&[1048576]int32{0, -1, -2, -3},
+				&[1048576]int64{0, -1, -2, -3},
+				&[1048576]float32{0, 1.1, 2.2, 3.3},
+				&[1048576]float64{0, 1.1, 2.2, 3.3},
+			},
+		},
+		{
+			Name: "pointers",
+			Objects: []interface{}{
+				&pointerStruct{A: &x, B: &x, C: &y, D: &y, AA: &xp, BB: &xp},
+				&pointerStruct{},
+			},
+		},
+		{
+			Name: "empty struct",
+			Objects: []interface{}{
+				struct{}{},
+			},
+		},
+		{
+			Name: "unenlightened structs",
+			Objects: []interface{}{
+				&dumbStruct{A: 1, B: 2},
+			},
+			Fail: true,
+		},
+		{
+			Name: "enlightened structs",
+			Objects: []interface{}{
+				&smartStruct{A: 1, B: 2},
+			},
+		},
+		{
+			Name: "load-hooks",
+			Objects: []interface{}{
+				&afterLoadStruct{v: 1},
+				&valueLoadStruct{v: 1},
+				&genericContainer{v: &afterLoadStruct{v: 1}},
+				&genericContainer{v: &valueLoadStruct{v: 1}},
+				&sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}},
+				&sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}},
+				&mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}},
+				&mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}},
+			},
+		},
+		{
+			Name: "maps",
+			Objects: []interface{}{
+				dumbMap{"a": -1, "b": 0, "c": 1},
+				map[smartStruct]int{{}: 0, {A: 1}: 1},
+				nilmap,
+				&mapContainer{v: map[int]interface{}{0: &smartStruct{A: 1}}},
+			},
+		},
+		{
+			Name: "interfaces",
+			Objects: []interface{}{
+				&testI{&testImpl{}},
+				&testI{nil},
+				&testI{(*testImpl)(nil)},
+			},
+		},
+		{
+			// Same fixtures as "load-hooks", but stored by value
+			// instead of by pointer inside the interface.
+			Name: "unregistered-interfaces",
+			Objects: []interface{}{
+				&genericContainer{v: afterLoadStruct{v: 1}},
+				&genericContainer{v: valueLoadStruct{v: 1}},
+				&sliceContainer{v: []interface{}{afterLoadStruct{v: 1}}},
+				&sliceContainer{v: []interface{}{valueLoadStruct{v: 1}}},
+				&mapContainer{v: map[int]interface{}{0: afterLoadStruct{v: 1}}},
+				&mapContainer{v: map[int]interface{}{0: valueLoadStruct{v: 1}}},
+			},
+			Fail: true,
+		},
+		{
+			Name: "cycles",
+			Objects: []interface{}{
+				&cs,
+				&cs1,
+				&cycleStruct{&cs1},
+				&cycleStruct{&cs},
+				&badCycleStruct{nil},
+				&bs,
+			},
+		},
+		{
+			Name: "deadlock",
+			Objects: []interface{}{
+				&bs1,
+			},
+			Fail: true,
+		},
+		{
+			Name: "embed",
+			Objects: []interface{}{
+				&embed1,
+				&embed2,
+			},
+			Fail: true,
+		},
+		{
+			Name: "empty structs",
+			Objects: []interface{}{
+				new(struct{}),
+				es,
+				es1,
+			},
+		},
+		{
+			Name: "truncated okay",
+			Objects: []interface{}{
+				&truncateInteger{v: 1},
+				&truncateUnsignedInteger{v: 1},
+				&truncateFloat{v: 1.0},
+			},
+		},
+		{
+			Name: "truncated bad",
+			Objects: []interface{}{
+				&truncateInteger{v: math.MaxInt32 + 1},
+				&truncateUnsignedInteger{v: math.MaxUint32 + 1},
+				&truncateFloat{v: math.MaxFloat32 * 2},
+			},
+			Fail: true,
+		},
+	}
+
+	runTest(t, tests)
+}
+
+// benchStruct is used for benchmarking.
+//
+// Instances form a singly linked chain through b (see buildObject).
+type benchStruct struct {
+ b *benchStruct
+
+ // Dummy data is included to ensure that these objects are large.
+ // This is to detect possible regression when registering objects.
+ _ [4096]byte
+}
+
+// save hands a pointer to the b field to m under key "b".
+func (b *benchStruct) save(m Map) {
+ m.Save("b", &b.b)
+}
+
+// load fills b with LoadWait and registers afterLoad as an AfterLoad hook.
+func (b *benchStruct) load(m Map) {
+ m.LoadWait("b", &b.b)
+ m.AfterLoad(b.afterLoad)
+}
+
+// afterLoad is the hook registered by load; it has an empty body.
+func (b *benchStruct) afterLoad() {
+ // Do nothing, just force scheduling.
+}
+
+// buildObject returns the head of a chain of n benchStructs, each pointing at
+// the one created before it. For n <= 0 the result is nil.
+func buildObject(n int) (b *benchStruct) {
+	for remaining := n; remaining > 0; remaining-- {
+		next := &benchStruct{b: b}
+		b = next
+	}
+	return b
+}
+
+// BenchmarkEncoding times a single Save of a b.N-element benchStruct chain
+// into ioutil.Discard. Building the chain is excluded from the timing, and a
+// per-type breakdown is logged for runs with more than 1000 elements.
+func BenchmarkEncoding(b *testing.B) {
+	b.StopTimer()
+	var stats Stats
+	chain := buildObject(b.N)
+
+	b.StartTimer()
+	err := Save(ioutil.Discard, chain, &stats)
+	if err != nil {
+		b.Errorf("save failed: %v", err)
+	}
+	b.StopTimer()
+
+	if b.N <= 1000 {
+		return
+	}
+	b.Logf("breakdown (n=%d): %s", b.N, &stats)
+}
+
+// BenchmarkDecoding times a single Load of a previously saved b.N-element
+// benchStruct chain. Building and saving the chain are excluded from the
+// timing, and a per-type breakdown is logged for runs with more than 1000
+// elements.
+func BenchmarkDecoding(b *testing.B) {
+	b.StopTimer()
+	chain := buildObject(b.N)
+	var loaded benchStruct
+	encoded := new(bytes.Buffer)
+	if err := Save(encoded, chain, nil); err != nil {
+		b.Errorf("save failed: %v", err)
+	}
+
+	var stats Stats
+	b.StartTimer()
+	err := Load(encoded, &loaded, &stats)
+	if err != nil {
+		b.Errorf("load failed: %v", err)
+	}
+	b.StopTimer()
+
+	if b.N <= 1000 {
+		return
+	}
+	b.Logf("breakdown (n=%d): %s", b.N, &stats)
+}
+
+// init registers every serializable fixture type of this test file with the
+// state package, pairing each type name with its save and load methods. The
+// registrations are performed in the order listed.
+func init() {
+	registrations := []struct {
+		name     string
+		instance interface{}
+		fns      Fns
+	}{
+		{"stateTest.smartStruct", (*smartStruct)(nil), Fns{
+			Save: (*smartStruct).save,
+			Load: (*smartStruct).load,
+		}},
+		{"stateTest.afterLoadStruct", (*afterLoadStruct)(nil), Fns{
+			Save: (*afterLoadStruct).save,
+			Load: (*afterLoadStruct).load,
+		}},
+		{"stateTest.valueLoadStruct", (*valueLoadStruct)(nil), Fns{
+			Save: (*valueLoadStruct).save,
+			Load: (*valueLoadStruct).load,
+		}},
+		{"stateTest.genericContainer", (*genericContainer)(nil), Fns{
+			Save: (*genericContainer).save,
+			Load: (*genericContainer).load,
+		}},
+		{"stateTest.sliceContainer", (*sliceContainer)(nil), Fns{
+			Save: (*sliceContainer).save,
+			Load: (*sliceContainer).load,
+		}},
+		{"stateTest.mapContainer", (*mapContainer)(nil), Fns{
+			Save: (*mapContainer).save,
+			Load: (*mapContainer).load,
+		}},
+		{"stateTest.pointerStruct", (*pointerStruct)(nil), Fns{
+			Save: (*pointerStruct).save,
+			Load: (*pointerStruct).load,
+		}},
+		{"stateTest.testImpl", (*testImpl)(nil), Fns{
+			Save: (*testImpl).save,
+			Load: (*testImpl).load,
+		}},
+		{"stateTest.testI", (*testI)(nil), Fns{
+			Save: (*testI).save,
+			Load: (*testI).load,
+		}},
+		{"stateTest.cycleStruct", (*cycleStruct)(nil), Fns{
+			Save: (*cycleStruct).save,
+			Load: (*cycleStruct).load,
+		}},
+		{"stateTest.badCycleStruct", (*badCycleStruct)(nil), Fns{
+			Save: (*badCycleStruct).save,
+			Load: (*badCycleStruct).load,
+		}},
+		{"stateTest.emptyStructPointer", (*emptyStructPointer)(nil), Fns{
+			Save: (*emptyStructPointer).save,
+			Load: (*emptyStructPointer).load,
+		}},
+		{"stateTest.truncateInteger", (*truncateInteger)(nil), Fns{
+			Save: (*truncateInteger).save,
+			Load: (*truncateInteger).load,
+		}},
+		{"stateTest.truncateUnsignedInteger", (*truncateUnsignedInteger)(nil), Fns{
+			Save: (*truncateUnsignedInteger).save,
+			Load: (*truncateUnsignedInteger).load,
+		}},
+		{"stateTest.truncateFloat", (*truncateFloat)(nil), Fns{
+			Save: (*truncateFloat).save,
+			Load: (*truncateFloat).load,
+		}},
+		{"stateTest.benchStruct", (*benchStruct)(nil), Fns{
+			Save: (*benchStruct).save,
+			Load: (*benchStruct).load,
+		}},
+	}
+	for _, r := range registrations {
+		Register(r.name, r.instance, r.fns)
+	}
+}
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
new file mode 100644
index 000000000..df2c6a578
--- /dev/null
+++ b/pkg/state/statefile/BUILD
@@ -0,0 +1,23 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+# Library target for the statefile package (header, metadata and data layout
+# of state files). Depends on the binary, compressio and hashio packages.
+go_library(
+ name = "statefile",
+ srcs = ["statefile.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/state/statefile",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/binary",
+ "//pkg/compressio",
+ "//pkg/hashio",
+ ],
+)
+
+# Unit tests and benchmarks for the statefile package. The library is embedded
+# so the tests can reach unexported identifiers such as keySize.
+go_test(
+ name = "statefile_test",
+ size = "small",
+ srcs = ["statefile_test.go"],
+ embed = [":statefile"],
+ deps = ["//pkg/hashio"],
+)
diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go
new file mode 100644
index 000000000..b25b743b7
--- /dev/null
+++ b/pkg/state/statefile/statefile.go
@@ -0,0 +1,233 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package statefile defines the state file data stream.
+//
+// This package currently does not include any details regarding the state
+// encoding itself, only details regarding state metadata and data layout.
+//
+// The file format is defined as follows.
+//
+// /------------------------------------------------------\
+// | header (8-bytes) |
+// +------------------------------------------------------+
+// | metadata length (8-bytes) |
+// +------------------------------------------------------+
+// | metadata |
+// +------------------------------------------------------+
+// | data |
+// \------------------------------------------------------/
+//
+// First, it includes an 8-byte magic header, which is the following
+// sequence of bytes: [0x67, 0x56, 0x69, 0x73, 0x6f, 0x72, 0x53, 0x46].
+//
+// This header is followed by an 8-byte length N (big endian), and an
+// ASCII-encoded JSON map that is exactly N bytes long.
+//
+// This map includes only strings for keys and strings for values. Keys in the
+// map that begin with "_" are for internal use only. They may be read, but may
+// not be provided by the user. In the future, this metadata may contain some
+// information relating to the state encoding itself.
+//
+// After the map, the remainder of the file is the state data.
+package statefile
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/json"
+ "fmt"
+ "hash"
+ "io"
+ "strings"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/compressio"
+ "gvisor.googlesource.com/gvisor/pkg/hashio"
+)
+
+// keySize is the AES-256 key length.
+//
+// NOTE(review): nothing in this file refers to keySize, and the key given to
+// NewWriter/NewReader is used as an HMAC-SHA256 key rather than for AES —
+// confirm whether the "AES-256" wording is still accurate.
+const keySize = 32
+
+// compressionChunkSize is the chunk size for compression (1 MiB). It is
+// passed to compressio.NewWriter by NewWriter.
+const compressionChunkSize = 1024 * 1024
+
+// maxMetadataSize is the size limit of metadata section (16 MiB). It is
+// enforced both when writing and when reading a state file.
+const maxMetadataSize = 16 * 1024 * 1024
+
+// magicHeader is the byte sequence beginning each file; the bytes spell
+// "gVisorSF" in ASCII.
+var magicHeader = []byte("\x67\x56\x69\x73\x6f\x72\x53\x46")
+
+// ErrBadMagic is returned if the header does not match.
+var ErrBadMagic = fmt.Errorf("bad magic header")
+
+// ErrMetadataMissing is returned if the state file is missing mandatory metadata.
+var ErrMetadataMissing = fmt.Errorf("missing metadata")
+
+// ErrInvalidMetadataLength is returned if the metadata length is too large.
+var ErrInvalidMetadataLength = fmt.Errorf("metadata length invalid, maximum size is %d", maxMetadataSize)
+
+// ErrMetadataInvalid is returned if passed metadata is invalid, i.e. if a
+// caller-supplied key starts with the reserved "_" prefix.
+var ErrMetadataInvalid = fmt.Errorf("metadata invalid, can't start with _")
+
+// NewWriter returns a state data writer for a statefile.
+//
+// Before returning, it writes the magic header, the metadata length, the
+// JSON-encoded metadata (with an added "_timestamp" entry) and the HMAC of
+// everything written so far to w. Data written to the returned writer is
+// compressed with the given compressionLevel and integrity-protected with
+// key.
+//
+// metadata may be nil. It is never modified: the timestamp is added to a
+// private copy, so the caller's map may safely be shared or reused. Keys
+// starting with "_" are reserved and cause ErrMetadataInvalid; encoded
+// metadata larger than maxMetadataSize causes ErrInvalidMetadataLength. Any
+// error from writing to w is returned as-is.
+//
+// Note that the returned WriteCloser must be closed.
+func NewWriter(w io.Writer, key []byte, metadata map[string]string, compressionLevel int) (io.WriteCloser, error) {
+	// Validate the user-provided keys while copying them, leaving room
+	// for the internal timestamp entry.
+	md := make(map[string]string, len(metadata)+1)
+	for k, v := range metadata {
+		if strings.HasPrefix(k, "_") {
+			return nil, ErrMetadataInvalid
+		}
+		md[k] = v
+	}
+
+	// Create our HMAC function. Everything written through mw is both
+	// emitted to w and fed into the hash.
+	h := hmac.New(sha256.New, key)
+	mw := io.MultiWriter(w, h)
+
+	// First, write the header.
+	if _, err := mw.Write(magicHeader); err != nil {
+		return nil, err
+	}
+
+	// Generate a timestamp, for convenience only.
+	md["_timestamp"] = time.Now().UTC().String()
+
+	// Write the metadata.
+	b, err := json.Marshal(md)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(b) > maxMetadataSize {
+		return nil, ErrInvalidMetadataLength
+	}
+
+	// Metadata length.
+	if err := binary.WriteUint64(mw, binary.BigEndian, uint64(len(b))); err != nil {
+		return nil, err
+	}
+	// Metadata bytes; io.MultiWriter will return a short write error if
+	// any of the writers returns < n.
+	if _, err := mw.Write(b); err != nil {
+		return nil, err
+	}
+	// Write the current hash. It goes through mw as well, so the hash
+	// bytes themselves become part of the running HMAC state; the reader
+	// mirrors this by reading them through its tee.
+	cur := h.Sum(nil)
+	for done := 0; done < len(cur); {
+		n, err := mw.Write(cur[done:])
+		done += n
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	w = hashio.NewWriter(w, h)
+
+	// Wrap in compression.
+	return compressio.NewWriter(w, compressionChunkSize, compressionLevel)
+}
+
+// MetadataUnsafe reads out the metadata from a state file without verifying any
+// HMAC. This function shouldn't be called for untrusted input files.
+func MetadataUnsafe(r io.Reader) (map[string]string, error) {
+ return metadata(r, nil)
+}
+
+// metadata validates the magic header and reads out the metadata from a state
+// data stream.
+//
+// If h is non-nil, every byte consumed from r is also fed into h, and the
+// hash stored after the metadata is compared against h's running sum before
+// the metadata is decoded; a mismatch yields hashio.ErrHashMismatch. If h is
+// nil, no integrity check is performed and the stored hash is not consumed.
+//
+// It returns ErrBadMagic if the stream does not start with magicHeader
+// (including when it ends part-way through the header),
+// ErrInvalidMetadataLength if the encoded length exceeds maxMetadataSize, and
+// otherwise any read or JSON decoding error encountered.
+func metadata(r io.Reader, h hash.Hash) (map[string]string, error) {
+	if h != nil {
+		r = io.TeeReader(r, h)
+	}
+
+	// Read and validate magic header. A single Read call may legally
+	// return fewer bytes than requested, so read until the buffer is full.
+	b := make([]byte, len(magicHeader))
+	if _, err := io.ReadFull(r, b); err != nil {
+		if err == io.ErrUnexpectedEOF {
+			// The stream ended inside the header; it cannot match.
+			return nil, ErrBadMagic
+		}
+		return nil, err
+	}
+	if !bytes.Equal(b, magicHeader) {
+		return nil, ErrBadMagic
+	}
+
+	// Read and validate metadata. Any panic raised while decoding the
+	// length or allocating the buffer is converted into an error.
+	b, err := func() (b []byte, err error) {
+		defer func() {
+			if r := recover(); r != nil {
+				b = nil
+				err = fmt.Errorf("%v", r)
+			}
+		}()
+
+		metadataLen, err := binary.ReadUint64(r, binary.BigEndian)
+		if err != nil {
+			return nil, err
+		}
+		if metadataLen > maxMetadataSize {
+			return nil, ErrInvalidMetadataLength
+		}
+		b = make([]byte, int(metadataLen))
+		if _, err := io.ReadFull(r, b); err != nil {
+			return nil, err
+		}
+		return b, nil
+	}()
+	if err != nil {
+		return nil, err
+	}
+
+	if h != nil {
+		// Check the hash prior to decoding. The stored hash is read
+		// through the tee as well, matching the writer, which also
+		// feeds the hash bytes into its HMAC.
+		cur := h.Sum(nil)
+		buf := make([]byte, len(cur))
+		if _, err := io.ReadFull(r, buf); err != nil {
+			return nil, err
+		}
+		if !hmac.Equal(cur, buf) {
+			return nil, hashio.ErrHashMismatch
+		}
+	}
+
+	// Decode the metadata.
+	md := make(map[string]string)
+	if err := json.Unmarshal(b, &md); err != nil {
+		return nil, err
+	}
+
+	return md, nil
+}
+
+// NewReader returns a reader for a statefile.
+//
+// It consumes and verifies the header and metadata from r using an
+// HMAC-SHA256 keyed with key, then returns a reader that yields the
+// integrity-checked, decompressed state data together with the decoded
+// metadata map. On any error both the reader and the map are nil.
+func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) {
+	// The same hash instance first covers the metadata section and is
+	// then handed on to the data reader.
+	mac := hmac.New(sha256.New, key)
+	md, err := metadata(r, mac)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// Layer integrity checking below decompression.
+	verified := hashio.NewReader(r, mac)
+	data, err := compressio.NewReader(verified)
+	if err != nil {
+		return nil, nil, err
+	}
+	return data, md, nil
+}
diff --git a/pkg/state/statefile/statefile_test.go b/pkg/state/statefile/statefile_test.go
new file mode 100644
index 000000000..6e67b51de
--- /dev/null
+++ b/pkg/state/statefile/statefile_test.go
@@ -0,0 +1,299 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package statefile
+
+import (
+ "bytes"
+ "compress/flate"
+ crand "crypto/rand"
+ "encoding/base64"
+ "io"
+ "math/rand"
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/hashio"
+)
+
+// randomKey returns keySize bytes of unpadded standard base64 text encoding
+// freshly generated random data, or the error from the random source.
+func randomKey() ([]byte, error) {
+	enc := base64.RawStdEncoding
+
+	// Draw as many raw bytes as decode from keySize encoded characters.
+	raw := make([]byte, enc.DecodedLen(keySize))
+	_, err := io.ReadFull(crand.Reader, raw)
+	if err != nil {
+		return nil, err
+	}
+
+	encoded := make([]byte, keySize)
+	enc.Encode(encoded, raw)
+	return encoded, nil
+}
+
+// testCase describes one round trip exercised by TestStatefile.
+type testCase struct {
+ // name is used as the subtest name.
+ name string
+ // data is the payload written through the writer and expected back
+ // from the reader.
+ data []byte
+ // metadata is the user metadata passed to NewWriter; it may be nil.
+ metadata map[string]string
+}
+
+// TestStatefile writes each case's payload and metadata with NewWriter and
+// reads them back with NewReader, once with a nil key and once with a random
+// key. It then checks that flipping a byte of the encoded stream, or using a
+// different key, makes reading fail.
+func TestStatefile(t *testing.T) {
+	cases := []testCase{
+		// Various data sizes.
+		{"nil", nil, nil},
+		{"empty", []byte(""), nil},
+		{"some", []byte("_"), nil},
+		{"one", []byte("0"), nil},
+		{"two", []byte("01"), nil},
+		{"three", []byte("012"), nil},
+		{"four", []byte("0123"), nil},
+		{"five", []byte("01234"), nil},
+		{"six", []byte("012356"), nil},
+		{"seven", []byte("0123567"), nil},
+		{"eight", []byte("01235678"), nil},
+
+		// Make sure we have one longer than the hash length.
+		{"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil},
+
+		// Make sure we have one longer than the segment size.
+		{"segments", make([]byte, 3*hashio.SegmentSize), nil},
+		{"segments minus one", make([]byte, 3*hashio.SegmentSize-1), nil},
+		{"segments plus one", make([]byte, 3*hashio.SegmentSize+1), nil},
+		{"segments minus hash", make([]byte, 3*hashio.SegmentSize-32), nil},
+		{"segments plus hash", make([]byte, 3*hashio.SegmentSize+32), nil},
+		{"large", make([]byte, 30*hashio.SegmentSize), nil},
+
+		// Different metadata.
+		{"one metadata", []byte("data"), map[string]string{"foo": "bar"}},
+		{"two metadata", []byte("data"), map[string]string{"foo": "bar", "one": "two"}},
+	}
+
+	for _, c := range cases {
+		// Generate a key.
+		integrityKey, err := randomKey()
+		if err != nil {
+			t.Errorf("can't generate key: got %v, expected nil", err)
+			continue
+		}
+
+		t.Run(c.name, func(t *testing.T) {
+			for _, key := range [][]byte{nil, integrityKey} {
+				t.Run("key="+string(key), func(t *testing.T) {
+					// Encoding happens via a buffer.
+					var bufEncoded bytes.Buffer
+					var bufDecoded bytes.Buffer
+
+					// Do all the writing.
+					w, err := NewWriter(&bufEncoded, key, c.metadata, flate.BestSpeed)
+					if err != nil {
+						t.Fatalf("error creating writer: got %v, expected nil", err)
+					}
+					if _, err := io.Copy(w, bytes.NewBuffer(c.data)); err != nil {
+						t.Fatalf("error during write: got %v, expected nil", err)
+					}
+
+					// Finish the sum.
+					if err := w.Close(); err != nil {
+						t.Fatalf("error during close: got %v, expected nil", err)
+					}
+
+					t.Logf("original data: %d bytes, encoded: %d bytes.",
+						len(c.data), len(bufEncoded.Bytes()))
+
+					// Do all the reading.
+					r, metadata, err := NewReader(bytes.NewReader(bufEncoded.Bytes()), key)
+					if err != nil {
+						t.Fatalf("error creating reader: got %v, expected nil", err)
+					}
+					if _, err := io.Copy(&bufDecoded, r); err != nil {
+						t.Fatalf("error during read: got %v, expected nil", err)
+					}
+
+					// Check that the data matches.
+					if !bytes.Equal(c.data, bufDecoded.Bytes()) {
+						t.Fatalf("data didn't match (%d vs %d bytes)", len(bufDecoded.Bytes()), len(c.data))
+					}
+
+					// Check that the metadata matches. Only the
+					// user-provided keys are compared; the reader
+					// may return additional internal entries.
+					for k, v := range c.metadata {
+						nv, ok := metadata[k]
+						if !ok {
+							t.Fatalf("missing metadata: %s", k)
+						}
+						if v != nv {
+							t.Fatalf("mismatched metadata for %s: got %s, expected %s", k, nv, v)
+						}
+					}
+
+					// Change the data and verify that it fails. The
+					// failure may surface either when the reader is
+					// created or while the data is being read.
+					b := append([]byte(nil), bufEncoded.Bytes()...)
+					b[rand.Intn(len(b))]++
+					r, _, err = NewReader(bytes.NewReader(b), key)
+					if err == nil {
+						_, err = io.Copy(&bufDecoded, r)
+					}
+					if err == nil {
+						t.Error("got no error: expected error on data corruption")
+					}
+
+					// Change the key and verify that it fails.
+					if key == nil {
+						key = integrityKey
+					} else {
+						key[rand.Intn(len(key))]++
+					}
+					r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), key)
+					if err == nil {
+						_, err = io.Copy(&bufDecoded, r)
+					}
+					if err != hashio.ErrHashMismatch {
+						t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err)
+					}
+				})
+			}
+		})
+	}
+}
+
+// benchmarkDataSize is the number of raw bytes (10 MiB) used to build the
+// benchmark payload.
+const benchmarkDataSize = 10 * 1024 * 1024
+
+// benchmark measures statefile throughput.
+//
+// size is the chunk size, in bytes, of each individual Write or Read call.
+// write selects whether writing (true) or reading (false) is timed.
+// compressible selects an all-zero payload (true) or base64-encoded
+// pseudo-random data (false). Payload generation and the initial untimed
+// write are excluded from the measurement.
+func benchmark(b *testing.B, size int, write bool, compressible bool) {
+	b.StopTimer()
+
+	// Generate source data.
+	var source []byte
+	if compressible {
+		// For compressible data, we use essentially all zeros.
+		source = make([]byte, benchmarkDataSize)
+	} else {
+		// For non-compressible data, we use random base64 data (to
+		// make it marginally compressible, a ratio of 75%).
+		var sourceBuf bytes.Buffer
+		bufW := base64.NewEncoder(base64.RawStdEncoding, &sourceBuf)
+		bufR := rand.New(rand.NewSource(0))
+		if _, err := io.CopyN(bufW, bufR, benchmarkDataSize); err != nil {
+			b.Fatalf("unable to seed random data: %v", err)
+		}
+		// The encoder buffers incomplete 3-byte groups; Close flushes
+		// them so that no input is silently dropped.
+		if err := bufW.Close(); err != nil {
+			b.Fatalf("unable to seed random data: %v", err)
+		}
+		source = sourceBuf.Bytes()
+	}
+
+	// Report throughput against the bytes actually processed per
+	// iteration; base64 encoding makes the non-compressible payload larger
+	// than benchmarkDataSize.
+	b.SetBytes(int64(len(source)))
+
+	// Generate a random key for integrity check.
+	key, err := randomKey()
+	if err != nil {
+		b.Fatalf("error generating key: %v", err)
+	}
+
+	// Define our benchmark functions. Prior to running the readState
+	// function here, you must execute the writeState function at least
+	// once (done below).
+	var stateBuf bytes.Buffer
+	writeState := func() {
+		stateBuf.Reset()
+		w, err := NewWriter(&stateBuf, key, nil, flate.BestSpeed)
+		if err != nil {
+			b.Fatalf("error creating writer: %v", err)
+		}
+		for done := 0; done < len(source); {
+			chunk := size // limit size.
+			if done+chunk > len(source) {
+				chunk = len(source) - done
+			}
+			n, err := w.Write(source[done : done+chunk])
+			done += n
+			if n == 0 && err != nil {
+				b.Fatalf("error during write: %v", err)
+			}
+		}
+		if err := w.Close(); err != nil {
+			b.Fatalf("error closing writer: %v", err)
+		}
+	}
+	readState := func() {
+		tmpBuf := bytes.NewBuffer(stateBuf.Bytes())
+		r, _, err := NewReader(tmpBuf, key)
+		if err != nil {
+			b.Fatalf("error creating reader: %v", err)
+		}
+		// The decoded data is read back into source itself, which
+		// avoids allocating a second payload-sized buffer.
+		for done := 0; done < len(source); {
+			chunk := size // limit size.
+			if done+chunk > len(source) {
+				chunk = len(source) - done
+			}
+			n, err := r.Read(source[done : done+chunk])
+			done += n
+			if n == 0 && err != nil {
+				b.Fatalf("error during read: %v", err)
+			}
+		}
+	}
+	// Generate the state once without timing to ensure that buffers have
+	// been appropriately allocated.
+	writeState()
+	if write {
+		b.StartTimer()
+		for i := 0; i < b.N; i++ {
+			writeState()
+		}
+		b.StopTimer()
+	} else {
+		b.StartTimer()
+		for i := 0; i < b.N; i++ {
+			readState()
+		}
+		b.StopTimer()
+	}
+}
+
+// BenchmarkWrite1BCompressible times writing compressible data in 1-byte
+// chunks.
+func BenchmarkWrite1BCompressible(b *testing.B) {
+ benchmark(b, 1, true, true)
+}
+
+// BenchmarkWrite1BNoncompressible times writing non-compressible data in
+// 1-byte chunks.
+func BenchmarkWrite1BNoncompressible(b *testing.B) {
+ benchmark(b, 1, true, false)
+}
+
+// BenchmarkWrite4KCompressible times writing compressible data in 4096-byte
+// chunks.
+func BenchmarkWrite4KCompressible(b *testing.B) {
+ benchmark(b, 4096, true, true)
+}
+
+// BenchmarkWrite4KNoncompressible times writing non-compressible data in
+// 4096-byte chunks.
+func BenchmarkWrite4KNoncompressible(b *testing.B) {
+ benchmark(b, 4096, true, false)
+}
+
+// BenchmarkWrite1MCompressible times writing compressible data in 1 MiB
+// chunks.
+func BenchmarkWrite1MCompressible(b *testing.B) {
+ benchmark(b, 1024*1024, true, true)
+}
+
+// BenchmarkWrite1MNoncompressible times writing non-compressible data in
+// 1 MiB chunks.
+func BenchmarkWrite1MNoncompressible(b *testing.B) {
+ benchmark(b, 1024*1024, true, false)
+}
+
+// BenchmarkRead1BCompressible times reading compressible data in 1-byte
+// chunks.
+func BenchmarkRead1BCompressible(b *testing.B) {
+ benchmark(b, 1, false, true)
+}
+
+// BenchmarkRead1BNoncompressible times reading non-compressible data in
+// 1-byte chunks.
+func BenchmarkRead1BNoncompressible(b *testing.B) {
+ benchmark(b, 1, false, false)
+}
+
+// BenchmarkRead4KCompressible times reading compressible data in 4096-byte
+// chunks.
+func BenchmarkRead4KCompressible(b *testing.B) {
+ benchmark(b, 4096, false, true)
+}
+
+// BenchmarkRead4KNoncompressible times reading non-compressible data in
+// 4096-byte chunks.
+func BenchmarkRead4KNoncompressible(b *testing.B) {
+ benchmark(b, 4096, false, false)
+}
+
+// BenchmarkRead1MCompressible times reading compressible data in 1 MiB
+// chunks.
+func BenchmarkRead1MCompressible(b *testing.B) {
+ benchmark(b, 1024*1024, false, true)
+}
+
+// BenchmarkRead1MNoncompressible times reading non-compressible data in
+// 1 MiB chunks.
+func BenchmarkRead1MNoncompressible(b *testing.B) {
+ benchmark(b, 1024*1024, false, false)
+}
diff --git a/pkg/state/stats.go b/pkg/state/stats.go
new file mode 100644
index 000000000..1ebd8ebb4
--- /dev/null
+++ b/pkg/state/stats.go
@@ -0,0 +1,133 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "bytes"
+ "fmt"
+ "reflect"
+ "sort"
+ "time"
+)
+
+// statEntry accumulates the sample count and the total elapsed time recorded
+// for a single type.
+type statEntry struct {
+ // count is the number of samples added so far.
+ count uint
+ // total is the cumulative time charged to this entry.
+ total time.Duration
+}
+
+// Stats tracks encode / decode timing.
+//
+// This currently provides a meaningful String function and no other way to
+// extract stats about individual types.
+//
+// All exported receivers accept nil.
+type Stats struct {
+ // byType contains a breakdown of time spent by type. It is allocated
+ // lazily by sample, so a zero Stats value is ready for use.
+ byType map[reflect.Type]*statEntry
+
+ // stack contains the types of the objects in progress: Start pushes
+ // onto it and Done pops from it.
+ stack []reflect.Type
+
+ // last is the last start time. sample charges the time elapsed since
+ // last to an entry and then resets last to the current time.
+ last time.Time
+}
+
+// sample charges the time elapsed since s.last to typ's entry and bumps that
+// entry's sample counter by count. The byType map and the entry itself are
+// created on first use, and s.last is advanced to the moment of the call so
+// that the next sample only measures the interval that follows.
+func (s *Stats) sample(typ reflect.Type, count uint) {
+	if s.byType == nil {
+		s.byType = map[reflect.Type]*statEntry{}
+	}
+	e, found := s.byType[typ]
+	if !found {
+		e = &statEntry{}
+		s.byType[typ] = e
+	}
+	t := time.Now()
+	e.total += t.Sub(s.last)
+	e.count += count
+	s.last = t
+}
+
+// Start begins timing obj and pushes its type onto the in-progress stack.
+//
+// If another object is already in progress, the time elapsed so far is charged
+// to it first (without counting a sample for it); otherwise the clock is
+// simply reset. Calling Start on a nil *Stats is a no-op.
+func (s *Stats) Start(obj reflect.Value) {
+	if s == nil {
+		return
+	}
+	switch depth := len(s.stack); depth {
+	case 0:
+		// First time sample.
+		s.last = time.Now()
+	default:
+		// Attribute the interval so far to the enclosing object.
+		s.sample(s.stack[depth-1], 0)
+	}
+	s.stack = append(s.stack, obj.Type())
+}
+
+// Done finishes the sample for the object most recently passed to Start: the
+// elapsed time is charged to its type, its sample count is incremented by
+// one, and it is popped from the in-progress stack.
+//
+// Calling Done on a nil *Stats is a no-op. Every Done must be paired with a
+// preceding Start; calling it with nothing in progress panics.
+func (s *Stats) Done() {
+	if s == nil {
+		return
+	}
+	top := len(s.stack) - 1
+	s.sample(s.stack[top], 1)
+	s.stack = s.stack[:top]
+}
+
+// sliceEntry pairs a type with its statEntry so that the contents of
+// Stats.byType can be held in a slice, sorted, and printed by String.
+type sliceEntry struct {
+ typ reflect.Type
+ entry *statEntry
+}
+
+// String returns a table representation of the stats.
+//
+// Each row shows the total time, the sample count and the mean time per
+// sample for one type, and a final "[all]" row aggregates every type. Rows are
+// sorted by total time, descending, with ties broken by type name so that the
+// output is deterministic. A nil receiver, or one with no recorded data,
+// yields "(no data)".
+//
+// A type can have time charged to it before any of its samples completes
+// (Start charges the enclosing object with a count of zero), so String may see
+// entries whose count is zero when it is called while objects are still in
+// progress. Such rows, and the "[all]" row if no sample has completed at all,
+// report a per-sample time of zero instead of dividing by zero.
+func (s *Stats) String() string {
+	if s == nil || len(s.byType) == 0 {
+		return "(no data)"
+	}
+
+	// Build a list of stat entries.
+	ss := make([]sliceEntry, 0, len(s.byType))
+	for typ, entry := range s.byType {
+		ss = append(ss, sliceEntry{typ: typ, entry: entry})
+	}
+
+	// Sort by total time (descending). Map iteration order is random and
+	// sort.Slice is not stable, so fall back to the type name on ties.
+	sort.Slice(ss, func(i, j int) bool {
+		if ss[i].entry.total != ss[j].entry.total {
+			return ss[i].entry.total > ss[j].entry.total
+		}
+		return ss[i].typ.String() < ss[j].typ.String()
+	})
+
+	// per returns the mean duration of one sample, or zero when no sample
+	// has completed; integer division by zero would panic.
+	per := func(total time.Duration, count uint) time.Duration {
+		if count == 0 {
+			return 0
+		}
+		return total / time.Duration(count)
+	}
+
+	// Print the stat results.
+	const rule = "-------------+----------+----------+-------------\n"
+	var (
+		buf   bytes.Buffer
+		count uint
+		total time.Duration
+	)
+	buf.WriteString("\n")
+	fmt.Fprintf(&buf, "%12s | %8s | %8s | %s\n", "total", "count", "per", "type")
+	buf.WriteString(rule)
+	for _, se := range ss {
+		count += se.entry.count
+		total += se.entry.total
+		fmt.Fprintf(&buf, "%12s | %8d | %8s | %s\n",
+			se.entry.total, se.entry.count,
+			per(se.entry.total, se.entry.count), se.typ.String())
+	}
+	buf.WriteString(rule)
+	fmt.Fprintf(&buf, "%12s | %8d | %8s | [all]",
+		total, count, per(total, count))
+	return buf.String()
+}