diff options
Diffstat (limited to 'pkg/state')
-rw-r--r-- | pkg/state/BUILD | 77 | ||||
-rw-r--r-- | pkg/state/decode.go | 594 | ||||
-rw-r--r-- | pkg/state/encode.go | 454 | ||||
-rw-r--r-- | pkg/state/encode_unsafe.go | 81 | ||||
-rw-r--r-- | pkg/state/map.go | 221 | ||||
-rw-r--r-- | pkg/state/object.proto | 140 | ||||
-rw-r--r-- | pkg/state/printer.go | 188 | ||||
-rw-r--r-- | pkg/state/state.go | 349 | ||||
-rw-r--r-- | pkg/state/state_test.go | 719 | ||||
-rw-r--r-- | pkg/state/statefile/BUILD | 23 | ||||
-rw-r--r-- | pkg/state/statefile/statefile.go | 233 | ||||
-rw-r--r-- | pkg/state/statefile/statefile_test.go | 299 | ||||
-rw-r--r-- | pkg/state/stats.go | 133 |
13 files changed, 3511 insertions, 0 deletions
diff --git a/pkg/state/BUILD b/pkg/state/BUILD new file mode 100644 index 000000000..bb6415d9b --- /dev/null +++ b/pkg/state/BUILD @@ -0,0 +1,77 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +go_template_instance( + name = "addr_range", + out = "addr_range.go", + package = "state", + prefix = "addr", + template = "//pkg/segment:generic_range", + types = { + "T": "uintptr", + }, +) + +go_template_instance( + name = "addr_set", + out = "addr_set.go", + consts = { + "minDegree": "10", + }, + imports = { + "reflect": "reflect", + }, + package = "state", + prefix = "addr", + template = "//pkg/segment:generic_set", + types = { + "Key": "uintptr", + "Range": "addrRange", + "Value": "reflect.Value", + "Functions": "addrSetFunctions", + }, +) + +go_library( + name = "state", + srcs = [ + "addr_range.go", + "addr_set.go", + "decode.go", + "encode.go", + "encode_unsafe.go", + "map.go", + "printer.go", + "state.go", + "stats.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/state", + visibility = ["//:sandbox"], + deps = [ + ":object_go_proto", + "@com_github_golang_protobuf//proto:go_default_library", + ], +) + +proto_library( + name = "object_proto", + srcs = ["object.proto"], + visibility = ["//:sandbox"], +) + +go_proto_library( + name = "object_go_proto", + importpath = "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto", + proto = ":object_proto", + visibility = ["//:sandbox"], +) + +go_test( + name = "state_test", + size = "small", + srcs = ["state_test.go"], + embed = [":state"], +) diff --git a/pkg/state/decode.go b/pkg/state/decode.go new file mode 100644 index 000000000..05758495b --- /dev/null +++ b/pkg/state/decode.go @@ -0,0 +1,594 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "reflect" + "sort" + + "github.com/golang/protobuf/proto" + pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto" +) + +// objectState represents an object that may be in the process of being +// decoded. Specifically, it represents either a decoded object, or an an +// interest in a future object that will be decoded. When that interest is +// registered (via register), the storage for the object will be created, but +// it will not be decoded until the object is encountered in the stream. +type objectState struct { + // id is the id for this object. + // + // If this field is zero, then this is an anonymous (unregistered, + // non-reference primitive) object. This is immutable. + id uint64 + + // obj is the object. This may or may not be valid yet, depending on + // whether complete returns true. However, regardless of whether the + // object is valid, obj contains a final storage location for the + // object. This is immutable. + // + // Note that this must be addressable (obj.Addr() must not panic). + // + // The obj passed to the decode methods below will equal this obj only + // in the case of decoding the top-level object. However, the passed + // obj may represent individual fields, elements of a slice, etc. that + // are effectively embedded within the reflect.Value below but with + // distinct types. + obj reflect.Value + + // blockedBy is the number of dependencies this object has. + blockedBy int + + // blocking is a list of the objects blocked by this one. + blocking []*objectState + + // callbacks is a set of callbacks to execute on load. + callbacks []func() + + // path is the decoding path to the object. + path recoverable +} + +// complete indicates the object is complete. +func (os *objectState) complete() bool { + return os.blockedBy == 0 && len(os.callbacks) == 0 +} + +// checkComplete checks for completion. If the object is complete, pending +// callbacks will be executed and checkComplete will be called on downstream +// objects (those depending on this one). +func (os *objectState) checkComplete(stats *Stats) { + if os.blockedBy > 0 { + return + } + + // Fire all callbacks. + for _, fn := range os.callbacks { + stats.Start(os.obj) + fn() + stats.Done() + } + os.callbacks = nil + + // Clear all blocked objects. + for _, other := range os.blocking { + other.blockedBy-- + other.checkComplete(stats) + } + os.blocking = nil +} + +// waitFor queues a dependency on the given object. +func (os *objectState) waitFor(other *objectState, callback func()) { + os.blockedBy++ + other.blocking = append(other.blocking, os) + if callback != nil { + other.callbacks = append(other.callbacks, callback) + } +} + +// findCycleFor returns when the given object is found in the blocking set. +func (os *objectState) findCycleFor(target *objectState) []*objectState { + for _, other := range os.blocking { + if other == target { + return []*objectState{target} + } else if childList := other.findCycleFor(target); childList != nil { + return append(childList, other) + } + } + return nil +} + +// findCycle finds a dependency cycle. +func (os *objectState) findCycle() []*objectState { + return append(os.findCycleFor(os), os) +} + +// decodeState is a graph of objects in the process of being decoded. +// +// The decode process involves loading the breadth-first graph generated by +// encode. This graph is read in it's entirety, ensuring that all object +// storage is complete. +// +// As the graph is being serialized, a set of completion callbacks are +// executed. These completion callbacks should form a set of acyclic subgraphs +// over the original one. After decoding is complete, the objects are scanned +// to ensure that all callbacks are executed, otherwise the callback graph was +// not acyclic. +type decodeState struct { + // objectByID is the set of objects in progress. + objectsByID map[uint64]*objectState + + // deferred are objects that have been read, by no interest has been + // registered yet. These will be decoded once interest in registered. + deferred map[uint64]*pb.Object + + // outstanding is the number of outstanding objects. + outstanding uint32 + + // r is the input stream. + r io.Reader + + // stats is the passed stats object. + stats *Stats + + // recoverable is the panic recover facility. + recoverable +} + +// lookup looks up an object in decodeState or returns nil if no such object +// has been previously registered. +func (ds *decodeState) lookup(id uint64) *objectState { + return ds.objectsByID[id] +} + +// wait registers a dependency on an object. +// +// As a special case, we always allow _useable_ references back to the first +// decoding object because it may have fields that are already decoded. We also +// allow trivial self reference, since they can be handled internally. +func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) { + switch id { + case 0: + // Nil pointer; nothing to wait for. + fallthrough + case waiter.id: + // Trivial self reference. + fallthrough + case 1: + // Root object; see above. + if callback != nil { + callback() + } + return + } + + // No nil can be returned here. + waiter.waitFor(ds.lookup(id), callback) +} + +// waitObject notes a blocking relationship. +func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) { + if rv, ok := p.Value.(*pb.Object_RefValue); ok { + // Refs can encode pointers and maps. + ds.wait(os, rv.RefValue, callback) + } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok { + // See decodeObject; we need to wait for the array (if non-nil). + ds.wait(os, sv.SliceValue.RefValue, callback) + } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok { + // It's an interface (wait recurisvely). + ds.waitObject(os, iv.InterfaceValue.Value, callback) + } else if callback != nil { + // Nothing to wait for: execute the callback immediately. + callback() + } +} + +// register registers a decode with a type. +// +// This type is only used to instantiate a new object if it has not been +// registered previously. +func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState { + os, ok := ds.objectsByID[id] + if ok { + return os + } + + // Record in the object index. + if typ.Kind() == reflect.Map { + os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()} + } else { + os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()} + } + ds.objectsByID[id] = os + + if o, ok := ds.deferred[id]; ok { + // There is a deferred object. + delete(ds.deferred, id) // Free memory. + ds.decodeObject(os, os.obj, o, "", nil) + } else { + // There is no deferred object. + ds.outstanding++ + } + + return os +} + +// decodeStruct decodes a struct value. +func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) { + // Set the fields. + m := Map{newInternalMap(nil, ds, os)} + defer internalMapPool.Put(m.internalMap) + for _, field := range s.Fields { + m.data = append(m.data, entry{ + name: field.Name, + object: field.Value, + }) + } + + // Sort the fields for efficient searching. + // + // Technically, these should already appear in sorted order in the + // state ordering, so this cost is effectively a single scan to ensure + // that the order is correct. + if len(m.data) > 1 { + sort.Slice(m.data, func(i, j int) bool { + return m.data[i].name < m.data[j].name + }) + } + + // Invoke the load; this will recursively decode other objects. + fns, ok := registeredTypes.lookupFns(obj.Addr().Type()) + if ok { + // Invoke the loader. + fns.invokeLoad(obj.Addr(), m) + } else if obj.NumField() == 0 { + // Allow anonymous empty structs. + return + } else { + // Propagate an error. + panic(fmt.Errorf("unregistered type %s", obj.Type())) + } +} + +// decodeMap decodes a map value. +func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) { + if obj.IsNil() { + obj.Set(reflect.MakeMap(obj.Type())) + } + for i := 0; i < len(m.Keys); i++ { + // Decode the objects. + kv := reflect.New(obj.Type().Key()).Elem() + vv := reflect.New(obj.Type().Elem()).Elem() + ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i) + ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface()) + ds.waitObject(os, m.Keys[i], nil) + ds.waitObject(os, m.Values[i], nil) + + // Set in the map. + obj.SetMapIndex(kv, vv) + } +} + +// decodeArray decodes an array value. +func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) { + if len(a.Contents) != obj.Len() { + panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents))) + } + // Decode the contents into the array. + for i := 0; i < len(a.Contents); i++ { + ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i) + ds.waitObject(os, a.Contents[i], nil) + } +} + +// decodeInterface decodes an interface value. +func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) { + // Is this a nil value? + if i.Type == "" { + return // Just leave obj alone. + } + + // Get the dispatchable type. This may not be used if the given + // reference has already been resolved, but if not we need to know the + // type to create. + t, ok := registeredTypes.lookupType(i.Type) + if !ok { + panic(fmt.Errorf("no valid type for %q", i.Type)) + } + + if obj.Kind() != reflect.Map { + // Set the obj to be the given typed value; this actually sets + // obj to be a non-zero value -- namely, it inserts type + // information. There's no need to do this for maps. + obj.Set(reflect.Zero(t)) + } + + // Decode the dereferenced element; there is no need to wait here, as + // the interface object shares the current object state. + ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type) +} + +// decodeObject decodes a object value. +func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) { + ds.push(false, format, param) + ds.stats.Start(obj) + + switch x := object.GetValue().(type) { + case *pb.Object_BoolValue: + obj.SetBool(x.BoolValue) + case *pb.Object_StringValue: + obj.SetString(x.StringValue) + case *pb.Object_Int64Value: + obj.SetInt(x.Int64Value) + if obj.Int() != x.Int64Value { + panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type())) + } + case *pb.Object_Uint64Value: + obj.SetUint(x.Uint64Value) + if obj.Uint() != x.Uint64Value { + panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type())) + } + case *pb.Object_DoubleValue: + obj.SetFloat(x.DoubleValue) + if obj.Float() != x.DoubleValue { + panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type())) + } + case *pb.Object_RefValue: + // Resolve the pointer itself, even though the object may not + // be decoded yet. You need to use wait() in order to ensure + // that is the case. See wait above, and Map.Barrier. + if id := x.RefValue; id != 0 { + // Decoding the interface should have imparted type + // information, so from this point it's safe to resolve + // and use this dynamic information for actually + // creating the object in register. + // + // (For non-interfaces this is a no-op). + dyntyp := reflect.TypeOf(obj.Interface()) + if dyntyp.Kind() == reflect.Map { + obj.Set(ds.register(id, dyntyp).obj) + } else if dyntyp.Kind() == reflect.Ptr { + ds.push(true /* dereference */, "", nil) + obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) + ds.pop() + } else { + obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) + } + } else { + // We leave obj alone here. That's because if obj + // represents an interface, it may have been embued + // with type information in decodeInterface, and we + // don't want to destroy that information. + } + case *pb.Object_SliceValue: + // It's okay to slice the array here, since the contents will + // still be provided later on. These semantics are a bit + // strange but they are handled in the Map.Barrier properly. + // + // The special semantics of zero ref apply here too. + if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 { + v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem()) + obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity))) + } + case *pb.Object_ArrayValue: + ds.decodeArray(os, obj, x.ArrayValue) + case *pb.Object_StructValue: + ds.decodeStruct(os, obj, x.StructValue) + case *pb.Object_MapValue: + ds.decodeMap(os, obj, x.MapValue) + case *pb.Object_InterfaceValue: + ds.decodeInterface(os, obj, x.InterfaceValue) + case *pb.Object_ByteArrayValue: + copyArray(obj, reflect.ValueOf(x.ByteArrayValue)) + case *pb.Object_Uint16ArrayValue: + // 16-bit slices are serialized as 32-bit slices. + // See object.proto for details. + s := x.Uint16ArrayValue.Values + t := obj.Slice(0, obj.Len()).Interface().([]uint16) + if len(t) != len(s) { + panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) + } + for i := range s { + t[i] = uint16(s[i]) + } + case *pb.Object_Uint32ArrayValue: + copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values)) + case *pb.Object_Uint64ArrayValue: + copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values)) + case *pb.Object_UintptrArrayValue: + copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0)))) + case *pb.Object_Int8ArrayValue: + copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0)))) + case *pb.Object_Int16ArrayValue: + // 16-bit slices are serialized as 32-bit slices. + // See object.proto for details. + s := x.Int16ArrayValue.Values + t := obj.Slice(0, obj.Len()).Interface().([]int16) + if len(t) != len(s) { + panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) + } + for i := range s { + t[i] = int16(s[i]) + } + case *pb.Object_Int32ArrayValue: + copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values)) + case *pb.Object_Int64ArrayValue: + copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values)) + case *pb.Object_BoolArrayValue: + copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values)) + case *pb.Object_Float64ArrayValue: + copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values)) + case *pb.Object_Float32ArrayValue: + copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values)) + default: + // Shoud not happen, not propagated as an error. + panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type())) + } + + ds.stats.Done() + ds.pop() +} + +func copyArray(dest reflect.Value, src reflect.Value) { + if dest.Len() != src.Len() { + panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len())) + } + reflect.Copy(dest, castSlice(src, dest.Type().Elem())) +} + +// Deserialize deserializes the object state. +// +// This function may panic and should be run in safely(). +func (ds *decodeState) Deserialize(obj reflect.Value) { + ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()} + ds.outstanding = 1 // The root object. + + // Decode all objects in the stream. + // + // See above, we never process objects while we have no outstanding + // interests (other than the very first object). + for id := uint64(1); ds.outstanding > 0; id++ { + o, err := ds.readObject() + if err != nil { + panic(err) + } + + os := ds.lookup(id) + if os != nil { + // Decode the object. + ds.from = &os.path + ds.decodeObject(os, os.obj, o, "", nil) + ds.outstanding-- + } else { + // If an object hasn't had interest registered + // previously, we deferred decoding until interest is + // registered. + ds.deferred[id] = o + } + } + + // Check the zero-length header at the end. + length, object, err := ReadHeader(ds.r) + if err != nil { + panic(err) + } + if length != 0 { + panic(fmt.Sprintf("expected zero-length terminal, got %d", length)) + } + if object { + panic("expected non-object terminal") + } + + // Check if we have any deferred objects. + if count := len(ds.deferred); count > 0 { + // Shoud not happen, not propagated as an error. + panic(fmt.Sprintf("still have %d deferred objects", count)) + } + + // Scan and fire all callbacks. + for _, os := range ds.objectsByID { + os.checkComplete(ds.stats) + } + + // Check if we have any remaining dependency cycles. + for _, os := range ds.objectsByID { + if !os.complete() { + // This must be the result of a dependency cycle. + cycle := os.findCycle() + var buf bytes.Buffer + buf.WriteString("dependency cycle: {") + for i, cycleOS := range cycle { + if i > 0 { + buf.WriteString(" => ") + } + buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type())) + } + buf.WriteString("}") + // Panic as an error; propagate to the caller. + panic(errors.New(string(buf.Bytes()))) + } + } +} + +type byteReader struct { + io.Reader +} + +// ReadByte implements io.ByteReader. +func (br byteReader) ReadByte() (byte, error) { + var b [1]byte + n, err := br.Reader.Read(b[:]) + if n > 0 { + return b[0], nil + } else if err != nil { + return 0, err + } else { + return 0, io.ErrUnexpectedEOF + } +} + +// ReadHeader reads an object header. +// +// Each object written to the statefile is prefixed with a header. See +// WriteHeader for more information; these functions are exported to allow +// non-state writes to the file to play nice with debugging tools. +func ReadHeader(r io.Reader) (length uint64, object bool, err error) { + // Read the header. + length, err = binary.ReadUvarint(byteReader{r}) + if err != nil { + return + } + + // Decode whether the object is valid. + object = length&0x1 != 0 + length = length >> 1 + return +} + +// readObject reads an object from the stream. +func (ds *decodeState) readObject() (*pb.Object, error) { + // Read the header. + length, object, err := ReadHeader(ds.r) + if err != nil { + return nil, err + } + if !object { + return nil, fmt.Errorf("invalid object header") + } + + // Read the object. + buf := make([]byte, length) + for done := 0; done < len(buf); { + n, err := ds.r.Read(buf[done:]) + done += n + if n == 0 && err != nil { + return nil, err + } + } + + // Unmarshal. + obj := new(pb.Object) + if err := proto.Unmarshal(buf, obj); err != nil { + return nil, err + } + + return obj, nil +} diff --git a/pkg/state/encode.go b/pkg/state/encode.go new file mode 100644 index 000000000..eb6527afc --- /dev/null +++ b/pkg/state/encode.go @@ -0,0 +1,454 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "container/list" + "encoding/binary" + "fmt" + "io" + "reflect" + "sort" + + "github.com/golang/protobuf/proto" + pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto" +) + +// queuedObject is an object queued for encoding. +type queuedObject struct { + id uint64 + obj reflect.Value + path recoverable +} + +// encodeState is state used for encoding. +// +// The encoding process is a breadth-first traversal of the object graph. The +// inherent races and dependencies are much simpler than the decode case. +type encodeState struct { + // lastID is the last object ID. + // + // See idsByObject for context. Because of the special zero encoding + // used for reference values, the first ID must be 1. + lastID uint64 + + // idsByObject is a set of objects, indexed via: + // + // reflect.ValueOf(x).UnsafeAddr + // + // This provides IDs for objects. + idsByObject map[uintptr]uint64 + + // values stores values that span the addresses. + // + // addrSet is a a generated type which efficiently stores ranges of + // addresses. When encoding pointers, these ranges are filled in and + // used to check for overlapping or conflicting pointers. This would + // indicate a pointer to an field, or a non-type safe value, neither of + // which are currently decodable. + // + // See the usage of values below for more context. + values addrSet + + // w is the output stream. + w io.Writer + + // pending is the list of objects to be serialized. + // + // This is a set of queuedObjects. + pending list.List + + // done is the a list of finished objects. + // + // This is kept to prevent garbage collection and address reuse. + done list.List + + // stats is the passed stats object. + stats *Stats + + // recoverable is the panic recover facility. + recoverable +} + +// register looks up an ID, registering if necessary. +// +// If the object was not previosly registered, it is enqueued to be serialized. +// See the documentation for idsByObject for more information. +func (es *encodeState) register(obj reflect.Value) uint64 { + // It is not legal to call register for any non-pointer objects (see + // below), so we panic with a recoverable error if this is a mismatch. + if obj.Kind() != reflect.Ptr && obj.Kind() != reflect.Map { + panic(fmt.Errorf("non-pointer %#v registered", obj.Interface())) + } + + addr := obj.Pointer() + if obj.Kind() == reflect.Ptr && obj.Elem().Type().Size() == 0 { + // For zero-sized objects, we always provide a unique ID. + // That's because the runtime internally multiplexes pointers + // to the same address. We can't be certain what the intent is + // with pointers to zero-sized objects, so we just give them + // all unique identities. + } else if id, ok := es.idsByObject[addr]; ok { + // Already registered. + return id + } + + // Ensure that the first ID given out is one. See note on lastID. The + // ID zero is used to indicate nil values. + es.lastID++ + id := es.lastID + es.idsByObject[addr] = id + if obj.Kind() == reflect.Ptr { + // Dereference and treat as a pointer. + es.pending.PushBack(queuedObject{id: id, obj: obj.Elem(), path: es.recoverable.copy()}) + + // Register this object at all addresses. + typ := obj.Elem().Type() + if size := typ.Size(); size > 0 { + r := addrRange{addr, addr + size} + if !es.values.IsEmptyRange(r) { + panic(fmt.Errorf("overlapping objects: [new object] %#v [existing object] %#v", obj.Interface(), es.values.FindSegment(addr).Value().Elem().Interface())) + } + es.values.Add(r, obj) + } + } else { + // Push back the map itself; when maps are encoded from the + // top-level, forceMap will be equal to true. + es.pending.PushBack(queuedObject{id: id, obj: obj, path: es.recoverable.copy()}) + } + + return id +} + +// encodeMap encodes a map. +func (es *encodeState) encodeMap(obj reflect.Value) *pb.Map { + var ( + keys []*pb.Object + values []*pb.Object + ) + for i, k := range obj.MapKeys() { + v := obj.MapIndex(k) + kp := es.encodeObject(k, false, ".(key %d)", i) + vp := es.encodeObject(v, false, "[%#v]", k.Interface()) + keys = append(keys, kp) + values = append(values, vp) + } + return &pb.Map{Keys: keys, Values: values} +} + +// encodeStruct encodes a composite object. +func (es *encodeState) encodeStruct(obj reflect.Value) *pb.Struct { + // Invoke the save. + m := Map{newInternalMap(es, nil, nil)} + defer internalMapPool.Put(m.internalMap) + if !obj.CanAddr() { + // Force it to a * type of the above; this involves a copy. + localObj := reflect.New(obj.Type()) + localObj.Elem().Set(obj) + obj = localObj.Elem() + } + fns, ok := registeredTypes.lookupFns(obj.Addr().Type()) + if ok { + // Invoke the provided saver. + fns.invokeSave(obj.Addr(), m) + } else if obj.NumField() == 0 { + // Allow unregistered anonymous, empty structs. + return &pb.Struct{} + } else { + // Propagate an error. + panic(fmt.Errorf("unregistered type %T", obj.Interface())) + } + + // Sort the underlying slice, and check for duplicates. This is done + // once instead of on each add, because performing this sort once is + // far more efficient. + if len(m.data) > 1 { + sort.Slice(m.data, func(i, j int) bool { + return m.data[i].name < m.data[j].name + }) + for i := range m.data { + if i > 0 && m.data[i-1].name == m.data[i].name { + panic(fmt.Errorf("duplicate name %s", m.data[i].name)) + } + } + } + + // Encode the resulting fields. + fields := make([]*pb.Field, 0, len(m.data)) + for _, e := range m.data { + fields = append(fields, &pb.Field{ + Name: e.name, + Value: e.object, + }) + } + + // Return the encoded object. + return &pb.Struct{Fields: fields} +} + +// encodeArray encodes an array. +func (es *encodeState) encodeArray(obj reflect.Value) *pb.Array { + var ( + contents []*pb.Object + ) + for i := 0; i < obj.Len(); i++ { + entry := es.encodeObject(obj.Index(i), false, "[%d]", i) + contents = append(contents, entry) + } + return &pb.Array{Contents: contents} +} + +// encodeInterface encodes an interface. +// +// Precondition: the value is not nil. +func (es *encodeState) encodeInterface(obj reflect.Value) *pb.Interface { + // Check for the nil interface. + obj = reflect.ValueOf(obj.Interface()) + if !obj.IsValid() { + return &pb.Interface{ + Type: "", // left alone in decode. + Value: &pb.Object{Value: &pb.Object_RefValue{0}}, + } + } + // We have an interface value here. How do we save that? We + // resolve the underlying type and save it as a dispatchable. + typName, ok := registeredTypes.lookupName(obj.Type()) + if !ok { + panic(fmt.Errorf("type %s is not registered", obj.Type())) + } + + // Encode the object again. + return &pb.Interface{ + Type: typName, + Value: es.encodeObject(obj, false, ".(%s)", typName), + } +} + +// encodeObject encodes an object. +// +// If mapAsValue is true, then a map will be encoded directly. +func (es *encodeState) encodeObject(obj reflect.Value, mapAsValue bool, format string, param interface{}) (object *pb.Object) { + es.push(false, format, param) + es.stats.Start(obj) + + switch obj.Kind() { + case reflect.Bool: + object = &pb.Object{Value: &pb.Object_BoolValue{obj.Bool()}} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + object = &pb.Object{Value: &pb.Object_Int64Value{obj.Int()}} + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + object = &pb.Object{Value: &pb.Object_Uint64Value{obj.Uint()}} + case reflect.Float32, reflect.Float64: + object = &pb.Object{Value: &pb.Object_DoubleValue{obj.Float()}} + case reflect.Array: + switch obj.Type().Elem().Kind() { + case reflect.Uint8: + object = &pb.Object{Value: &pb.Object_ByteArrayValue{pbSlice(obj).Interface().([]byte)}} + case reflect.Uint16: + // 16-bit slices are serialized as 32-bit slices. + // See object.proto for details. + s := pbSlice(obj).Interface().([]uint16) + t := make([]uint32, len(s)) + for i := range s { + t[i] = uint32(s[i]) + } + object = &pb.Object{Value: &pb.Object_Uint16ArrayValue{&pb.Uint16S{Values: t}}} + case reflect.Uint32: + object = &pb.Object{Value: &pb.Object_Uint32ArrayValue{&pb.Uint32S{Values: pbSlice(obj).Interface().([]uint32)}}} + case reflect.Uint64: + object = &pb.Object{Value: &pb.Object_Uint64ArrayValue{&pb.Uint64S{Values: pbSlice(obj).Interface().([]uint64)}}} + case reflect.Uintptr: + object = &pb.Object{Value: &pb.Object_UintptrArrayValue{&pb.Uintptrs{Values: pbSlice(obj).Interface().([]uint64)}}} + case reflect.Int8: + object = &pb.Object{Value: &pb.Object_Int8ArrayValue{&pb.Int8S{Values: pbSlice(obj).Interface().([]byte)}}} + case reflect.Int16: + // 16-bit slices are serialized as 32-bit slices. + // See object.proto for details. + s := pbSlice(obj).Interface().([]int16) + t := make([]int32, len(s)) + for i := range s { + t[i] = int32(s[i]) + } + object = &pb.Object{Value: &pb.Object_Int16ArrayValue{&pb.Int16S{Values: t}}} + case reflect.Int32: + object = &pb.Object{Value: &pb.Object_Int32ArrayValue{&pb.Int32S{Values: pbSlice(obj).Interface().([]int32)}}} + case reflect.Int64: + object = &pb.Object{Value: &pb.Object_Int64ArrayValue{&pb.Int64S{Values: pbSlice(obj).Interface().([]int64)}}} + case reflect.Bool: + object = &pb.Object{Value: &pb.Object_BoolArrayValue{&pb.Bools{Values: pbSlice(obj).Interface().([]bool)}}} + case reflect.Float32: + object = &pb.Object{Value: &pb.Object_Float32ArrayValue{&pb.Float32S{Values: pbSlice(obj).Interface().([]float32)}}} + case reflect.Float64: + object = &pb.Object{Value: &pb.Object_Float64ArrayValue{&pb.Float64S{Values: pbSlice(obj).Interface().([]float64)}}} + default: + object = &pb.Object{Value: &pb.Object_ArrayValue{es.encodeArray(obj)}} + } + case reflect.Slice: + if obj.IsNil() || obj.Cap() == 0 { + // Handled specially in decode; store as nil value. + object = &pb.Object{Value: &pb.Object_RefValue{0}} + } else { + // Serialize a slice as the array plus length and capacity. + object = &pb.Object{Value: &pb.Object_SliceValue{&pb.Slice{ + Capacity: uint32(obj.Cap()), + Length: uint32(obj.Len()), + RefValue: es.register(arrayFromSlice(obj)), + }}} + } + case reflect.String: + object = &pb.Object{Value: &pb.Object_StringValue{obj.String()}} + case reflect.Ptr: + if obj.IsNil() { + // Handled specially in decode; store as a nil value. + object = &pb.Object{Value: &pb.Object_RefValue{0}} + } else { + es.push(true /* dereference */, "", nil) + object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}} + es.pop() + } + case reflect.Interface: + // We don't check for IsNil here, as we want to encode type + // information. The case of the empty interface (no type, no + // value) is handled by encodeInteface. + object = &pb.Object{Value: &pb.Object_InterfaceValue{es.encodeInterface(obj)}} + case reflect.Struct: + object = &pb.Object{Value: &pb.Object_StructValue{es.encodeStruct(obj)}} + case reflect.Map: + if obj.IsNil() { + // Handled specially in decode; store as a nil value. + object = &pb.Object{Value: &pb.Object_RefValue{0}} + } else if mapAsValue { + // Encode the map directly. + object = &pb.Object{Value: &pb.Object_MapValue{es.encodeMap(obj)}} + } else { + // Encode a reference to the map. + object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}} + } + default: + panic(fmt.Errorf("unknown primitive %#v", obj.Interface())) + } + + es.stats.Done() + es.pop() + return +} + +// Serialize serializes the object state. +// +// This function may panic and should be run in safely(). +func (es *encodeState) Serialize(obj reflect.Value) { + es.register(obj.Addr()) + + // Pop off the list until we're done. + for es.pending.Len() > 0 { + e := es.pending.Front() + es.pending.Remove(e) + + // Extract the queued object. + qo := e.Value.(queuedObject) + es.from = &qo.path + o := es.encodeObject(qo.obj, true, "", nil) + + // Emit to our output stream. + if err := es.writeObject(qo.id, o); err != nil { + panic(err) + } + + // Mark as done. + es.done.PushBack(e) + } + + // Write a zero-length terminal at the end; this is a sanity check + // applied at decode time as well (see decode.go). + if err := WriteHeader(es.w, 0, false); err != nil { + panic(err) + } +} + +// WriteHeader writes a header. +// +// Each object written to the statefile should be prefixed with a header. In +// order to generate statefiles that play nicely with debugging tools, raw +// writes should be prefixed with a header with object set to false and the +// appropriate length. This will allow tools to skip these regions. +func WriteHeader(w io.Writer, length uint64, object bool) error { + // The lowest-order bit encodes whether this is a valid object. This is + // a purely internal convention, but allows the object flag to be + // returned from ReadHeader. + length = length << 1 + if object { + length |= 0x1 + } + + // Write a header. + var hdr [32]byte + encodedLen := binary.PutUvarint(hdr[:], length) + for done := 0; done < encodedLen; { + n, err := w.Write(hdr[done:encodedLen]) + done += n + if n == 0 && err != nil { + return err + } + } + + return nil +} + +// writeObject writes an object to the stream. +func (es *encodeState) writeObject(id uint64, obj *pb.Object) error { + // Marshal the proto. + buf, err := proto.Marshal(obj) + if err != nil { + return err + } + + // Write the object header. + if err := WriteHeader(es.w, uint64(len(buf)), true); err != nil { + return err + } + + // Write the object. + for done := 0; done < len(buf); { + n, err := es.w.Write(buf[done:]) + done += n + if n == 0 && err != nil { + return err + } + } + + return nil +} + +// addrSetFunctions is used by addrSet. +type addrSetFunctions struct{} + +func (addrSetFunctions) MinKey() uintptr { + return 0 +} + +func (addrSetFunctions) MaxKey() uintptr { + return ^uintptr(0) +} + +func (addrSetFunctions) ClearValue(val *reflect.Value) { +} + +func (addrSetFunctions) Merge(_ addrRange, val1 reflect.Value, _ addrRange, val2 reflect.Value) (reflect.Value, bool) { + return val1, val1 == val2 +} + +func (addrSetFunctions) Split(_ addrRange, val reflect.Value, _ uintptr) (reflect.Value, reflect.Value) { + return val, val +} diff --git a/pkg/state/encode_unsafe.go b/pkg/state/encode_unsafe.go new file mode 100644 index 000000000..d96ba56d4 --- /dev/null +++ b/pkg/state/encode_unsafe.go @@ -0,0 +1,81 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "reflect" + "unsafe" +) + +// arrayFromSlice constructs a new pointer to the slice data. +// +// It would be similar to the following: +// +// x := make([]Foo, l, c) +// a := ([l]Foo*)(unsafe.Pointer(x[0])) +// +func arrayFromSlice(obj reflect.Value) reflect.Value { + return reflect.NewAt( + reflect.ArrayOf(obj.Cap(), obj.Type().Elem()), + unsafe.Pointer(obj.Pointer())) +} + +// pbSlice returns a protobuf-supported slice of the array and erase the +// original element type (which could be a defined type or non-supported type). +func pbSlice(obj reflect.Value) reflect.Value { + var typ reflect.Type + switch obj.Type().Elem().Kind() { + case reflect.Uint8: + typ = reflect.TypeOf(byte(0)) + case reflect.Uint16: + typ = reflect.TypeOf(uint16(0)) + case reflect.Uint32: + typ = reflect.TypeOf(uint32(0)) + case reflect.Uint64: + typ = reflect.TypeOf(uint64(0)) + case reflect.Uintptr: + typ = reflect.TypeOf(uint64(0)) + case reflect.Int8: + typ = reflect.TypeOf(byte(0)) + case reflect.Int16: + typ = reflect.TypeOf(int16(0)) + case reflect.Int32: + typ = reflect.TypeOf(int32(0)) + case reflect.Int64: + typ = reflect.TypeOf(int64(0)) + case reflect.Bool: + typ = reflect.TypeOf(bool(false)) + case reflect.Float32: + typ = reflect.TypeOf(float32(0)) + case reflect.Float64: + typ = reflect.TypeOf(float64(0)) + default: + panic("slice element is not of basic value type") + } + return reflect.NewAt( + reflect.ArrayOf(obj.Len(), typ), + unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()), + ).Elem().Slice(0, obj.Len()) +} + +func castSlice(obj reflect.Value, elemTyp reflect.Type) reflect.Value { + if obj.Type().Elem().Size() != elemTyp.Size() { + panic("cannot cast slice into other element type of different size") + } + return reflect.NewAt( + reflect.ArrayOf(obj.Len(), elemTyp), + unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()), + ).Elem() +} diff --git a/pkg/state/map.go b/pkg/state/map.go new file mode 100644 index 000000000..c3d165501 --- /dev/null +++ b/pkg/state/map.go @@ -0,0 +1,221 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "fmt" + "reflect" + "sort" + "sync" + + pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto" +) + +// entry is a single map entry. +type entry struct { + name string + object *pb.Object +} + +// internalMap is the internal Map state. +// +// These are recycled via a pool to avoid churn. +type internalMap struct { + // es is encodeState. + es *encodeState + + // ds is decodeState. + ds *decodeState + + // os is current object being decoded. + // + // This will always be nil during encode. + os *objectState + + // data stores the encoded values. + data []entry +} + +var internalMapPool = sync.Pool{ + New: func() interface{} { + return new(internalMap) + }, +} + +// newInternalMap returns a cached map. +func newInternalMap(es *encodeState, ds *decodeState, os *objectState) *internalMap { + m := internalMapPool.Get().(*internalMap) + m.es = es + m.ds = ds + m.os = os + if m.data != nil { + m.data = m.data[:0] + } + return m +} + +// Map is a generic state container. +// +// This is the object passed to Save and Load in order to store their state. +// +// Detailed documentation is available in individual methods. +type Map struct { + *internalMap +} + +// Save adds the given object to the map. +// +// You should pass always pointers to the object you are saving. For example: +// +// type X struct { +// A int +// B *int +// } +// +// func (x *X) Save(m Map) { +// m.Save("A", &x.A) +// m.Save("B", &x.B) +// } +// +// func (x *X) Load(m Map) { +// m.Load("A", &x.A) +// m.Load("B", &x.B) +// } +func (m Map) Save(name string, objPtr interface{}) { + m.save(name, reflect.ValueOf(objPtr).Elem(), ".%s") +} + +// SaveValue adds the given object value to the map. +// +// This should be used for values where pointers are not available, or casts +// are required during Save/Load. +// +// For example, if we want to cast external package type P.Foo to int64: +// +// type X struct { +// A P.Foo +// } +// +// func (x *X) Save(m Map) { +// m.SaveValue("A", int64(x.A)) +// } +// +// func (x *X) Load(m Map) { +// m.LoadValue("A", new(int64), func(x interface{}) { +// x.A = P.Foo(x.(int64)) +// }) +// } +func (m Map) SaveValue(name string, obj interface{}) { + m.save(name, reflect.ValueOf(obj), ".(value %s)") +} + +// save is helper for the above. It takes the name of value to save the field +// to, the field object (obj), and a format string that specifies how the +// field's saving logic is dispatched from the struct (normal, value, etc.). The +// format string should expect one string parameter, which is the name of the +// field. +func (m Map) save(name string, obj reflect.Value, format string) { + if m.es == nil { + // Not currently encoding. + m.Failf("no encode state for %q", name) + } + + // Attempt the encode. + // + // These are sorted at the end, after all objects are added and will be + // sorted and checked for duplicates (see encodeStruct). + m.data = append(m.data, entry{ + name: name, + object: m.es.encodeObject(obj, false, format, name), + }) +} + +// Load loads the given object from the map. +// +// See Save for an example. +func (m Map) Load(name string, objPtr interface{}) { + m.load(name, reflect.ValueOf(objPtr), false, nil, ".%s") +} + +// LoadWait loads the given objects from the map, and marks it as requiring all +// AfterLoad executions to complete prior to running this object's AfterLoad. +// +// See Save for an example. +func (m Map) LoadWait(name string, objPtr interface{}) { + m.load(name, reflect.ValueOf(objPtr), true, nil, ".(wait %s)") +} + +// LoadValue loads the given object value from the map. +// +// See SaveValue for an example. +func (m Map) LoadValue(name string, objPtr interface{}, fn func(interface{})) { + o := reflect.ValueOf(objPtr) + m.load(name, o, true, func() { fn(o.Elem().Interface()) }, ".(value %s)") +} + +// load is helper for the above. It takes the name of value to load the field +// from, the target field pointer (objPtr), whether load completion of the +// struct depends on the field's load completion (wait), the load completion +// logic (fn), and a format string that specifies how the field's loading logic +// is dispatched from the struct (normal, wait, value, etc.). The format string +// should expect one string parameter, which is the name of the field. +func (m Map) load(name string, objPtr reflect.Value, wait bool, fn func(), format string) { + if m.ds == nil { + // Not currently decoding. + m.Failf("no decode state for %q", name) + } + + // Find the object. + // + // These are sorted up front (and should appear in the state file + // sorted as well), so we can do a binary search here to ensure that + // large structs don't behave badly. + i := sort.Search(len(m.data), func(i int) bool { + return m.data[i].name >= name + }) + if i >= len(m.data) || m.data[i].name != name { + // There is no data for this name? + m.Failf("no data found for %q", name) + } + + // Perform the decode. + m.ds.decodeObject(m.os, objPtr.Elem(), m.data[i].object, format, name) + if wait { + // Mark this individual object a blocker. + m.ds.waitObject(m.os, m.data[i].object, fn) + } +} + +// Failf fails the save or restore with the provided message. Processing will +// stop after calling Failf, as the state package uses a panic & recover +// mechanism for state errors. You should defer any cleanup required. +func (m Map) Failf(format string, args ...interface{}) { + panic(fmt.Errorf(format, args...)) +} + +// AfterLoad schedules a function execution when all objects have been allocated +// and their automated loading and customized load logic have been executed. fn +// will not be executed until all of current object's dependencies' AfterLoad() +// logic, if exist, have been executed. +func (m Map) AfterLoad(fn func()) { + if m.ds == nil { + // Not currently decoding. + m.Failf("not decoding") + } + + // Queue the local callback; this will execute when all of the above + // data dependencies have been cleared. + m.os.callbacks = append(m.os.callbacks, fn) +} diff --git a/pkg/state/object.proto b/pkg/state/object.proto new file mode 100644 index 000000000..6595c5519 --- /dev/null +++ b/pkg/state/object.proto @@ -0,0 +1,140 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package gvisor.state.statefile; + +// Slice is a slice value. +message Slice { + uint32 length = 1; + uint32 capacity = 2; + uint64 ref_value = 3; +} + +// Array is an array value. +message Array { + repeated Object contents = 1; +} + +// Map is a map value. +message Map { + repeated Object keys = 1; + repeated Object values = 2; +} + +// Interface is an interface value. +message Interface { + string type = 1; + Object value = 2; +} + +// Struct is a basic composite value. +message Struct { + repeated Field fields = 1; +} + +// Field encodes a single field. +message Field { + string name = 1; + Object value = 2; +} + +// Uint16s encodes an uint16 array. To be used inside oneof structure. +message Uint16s { + // There is no 16-bit type in protobuf so we use variable length 32-bit here. + repeated uint32 values = 1; +} + +// Uint32s encodes an uint32 array. To be used inside oneof structure. +message Uint32s { + repeated fixed32 values = 1; +} + +// Uint64s encodes an uint64 array. To be used inside oneof structure. +message Uint64s { + repeated fixed64 values = 1; +} + +// Uintptrs encodes an uintptr array. To be used inside oneof structure. +message Uintptrs { + repeated fixed64 values = 1; +} + +// Int8s encodes an int8 array. To be used inside oneof structure. +message Int8s { + bytes values = 1; +} + +// Int16s encodes an int16 array. To be used inside oneof structure. +message Int16s { + // There is no 16-bit type in protobuf so we use variable length 32-bit here. + repeated int32 values = 1; +} + +// Int32s encodes an int32 array. To be used inside oneof structure. +message Int32s { + repeated sfixed32 values = 1; +} + +// Int64s encodes an int64 array. To be used inside oneof structure. +message Int64s { + repeated sfixed64 values = 1; +} + +// Bools encodes a boolean array. To be used inside oneof structure. +message Bools { + repeated bool values = 1; +} + +// Float64s encodes a float64 array. To be used inside oneof structure. +message Float64s { + repeated double values = 1; +} + +// Float32s encodes a float32 array. To be used inside oneof structure. +message Float32s { + repeated float values = 1; +} + +// Object are primitive encodings. +// +// Note that ref_value references an Object.id, below. +message Object { + oneof value { + bool bool_value = 1; + string string_value = 2; + int64 int64_value = 3; + uint64 uint64_value = 4; + double double_value = 5; + uint64 ref_value = 6; + Slice slice_value = 7; + Array array_value = 8; + Interface interface_value = 9; + Struct struct_value = 10; + Map map_value = 11; + bytes byte_array_value = 12; + Uint16s uint16_array_value = 13; + Uint32s uint32_array_value = 14; + Uint64s uint64_array_value = 15; + Uintptrs uintptr_array_value = 16; + Int8s int8_array_value = 17; + Int16s int16_array_value = 18; + Int32s int32_array_value = 19; + Int64s int64_array_value = 20; + Bools bool_array_value = 21; + Float64s float64_array_value = 22; + Float32s float32_array_value = 23; + } +} diff --git a/pkg/state/printer.go b/pkg/state/printer.go new file mode 100644 index 000000000..c61ec4a26 --- /dev/null +++ b/pkg/state/printer.go @@ -0,0 +1,188 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "fmt" + "io" + "io/ioutil" + "strings" + + "github.com/golang/protobuf/proto" + pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto" +) + +// format formats a single object, for pretty-printing. +func format(graph uint64, depth int, object *pb.Object, html bool) (string, bool) { + switch x := object.GetValue().(type) { + case *pb.Object_BoolValue: + return fmt.Sprintf("%t", x.BoolValue), x.BoolValue != false + case *pb.Object_StringValue: + return fmt.Sprintf("\"%s\"", x.StringValue), x.StringValue != "" + case *pb.Object_Int64Value: + return fmt.Sprintf("%d", x.Int64Value), x.Int64Value != 0 + case *pb.Object_Uint64Value: + return fmt.Sprintf("%du", x.Uint64Value), x.Uint64Value != 0 + case *pb.Object_DoubleValue: + return fmt.Sprintf("%f", x.DoubleValue), x.DoubleValue != 0.0 + case *pb.Object_RefValue: + if x.RefValue == 0 { + return "nil", false + } + ref := fmt.Sprintf("g%dr%d", graph, x.RefValue) + if html { + ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref) + } + return ref, true + case *pb.Object_SliceValue: + if x.SliceValue.RefValue == 0 { + return "nil", false + } + ref := fmt.Sprintf("g%dr%d", graph, x.SliceValue.RefValue) + if html { + ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref) + } + return fmt.Sprintf("%s[:%d:%d]", ref, x.SliceValue.Length, x.SliceValue.Capacity), true + case *pb.Object_ArrayValue: + if len(x.ArrayValue.Contents) == 0 { + return "[]", false + } + items := make([]string, 0, len(x.ArrayValue.Contents)+2) + zeros := make([]string, 0) // used to eliminate zero entries. + items = append(items, "[") + tabs := "\n" + strings.Repeat("\t", depth) + for i := 0; i < len(x.ArrayValue.Contents); i++ { + item, ok := format(graph, depth+1, x.ArrayValue.Contents[i], html) + if ok { + if len(zeros) > 0 { + items = append(items, zeros...) + zeros = nil + } + items = append(items, fmt.Sprintf("\t%s,", item)) + } else { + zeros = append(zeros, fmt.Sprintf("\t%s,", item)) + } + } + if len(zeros) > 0 { + items = append(items, fmt.Sprintf("\t... (%d zero),", len(zeros))) + } + items = append(items, "]") + return strings.Join(items, tabs), len(zeros) < len(x.ArrayValue.Contents) + case *pb.Object_StructValue: + if len(x.StructValue.Fields) == 0 { + return "struct{}", false + } + items := make([]string, 0, len(x.StructValue.Fields)+2) + items = append(items, "struct{") + tabs := "\n" + strings.Repeat("\t", depth) + allZero := true + for _, field := range x.StructValue.Fields { + element, ok := format(graph, depth+1, field.Value, html) + allZero = allZero && !ok + items = append(items, fmt.Sprintf("\t%s: %s,", field.Name, element)) + } + items = append(items, "}") + return strings.Join(items, tabs), !allZero + case *pb.Object_MapValue: + if len(x.MapValue.Keys) == 0 { + return "map{}", false + } + items := make([]string, 0, len(x.MapValue.Keys)+2) + items = append(items, "map{") + tabs := "\n" + strings.Repeat("\t", depth) + for i := 0; i < len(x.MapValue.Keys); i++ { + key, _ := format(graph, depth+1, x.MapValue.Keys[i], html) + value, _ := format(graph, depth+1, x.MapValue.Values[i], html) + items = append(items, fmt.Sprintf("\t%s: %s,", key, value)) + } + items = append(items, "}") + return strings.Join(items, tabs), true + case *pb.Object_InterfaceValue: + if x.InterfaceValue.Type == "" { + return "interface(nil){}", false + } + element, _ := format(graph, depth+1, x.InterfaceValue.Value, html) + return fmt.Sprintf("interface(\"%s\"){%s}", x.InterfaceValue.Type, element), true + } + + // Should not happen, but tolerate. + return fmt.Sprintf("(unknown proto type: %T)", object.GetValue()), true +} + +// PrettyPrint reads the state stream from r, and pretty prints to w. +func PrettyPrint(w io.Writer, r io.Reader, html bool) error { + var ( + // current graph ID. + graph uint64 + + // current object ID. + id uint64 + ) + + if html { + fmt.Fprintf(w, "<pre>") + defer fmt.Fprintf(w, "</pre>") + } + + for { + // Find the first object to begin generation. + length, object, err := ReadHeader(r) + if err == io.EOF { + // Nothing else to do. + break + } else if err != nil { + return err + } + if !object { + // Increment the graph number & reset the ID. + graph++ + id = 0 + if length > 0 { + fmt.Fprintf(w, "(%d bytes non-object data)\n", length) + io.Copy(ioutil.Discard, &io.LimitedReader{ + R: r, + N: int64(length), + }) + } + continue + } + + // Read & unmarshal the object. + buf := make([]byte, length) + for done := 0; done < len(buf); { + n, err := r.Read(buf[done:]) + done += n + if n == 0 && err != nil { + return err + } + } + obj := new(pb.Object) + if err := proto.Unmarshal(buf, obj); err != nil { + return err + } + + id++ // First object must be one. + str, _ := format(graph, 0, obj, html) + tag := fmt.Sprintf("g%dr%d", graph, id) + if html { + tag = fmt.Sprintf("<a name=%s>%s</a>", tag, tag) + } + if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/state/state.go b/pkg/state/state.go new file mode 100644 index 000000000..23a0b5922 --- /dev/null +++ b/pkg/state/state.go @@ -0,0 +1,349 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package state provides functionality related to saving and loading object +// graphs. For most types, it provides a set of default saving / loading logic +// that will be invoked automatically if custom logic is not defined. +// +// Kind Support +// ---- ------- +// Bool default +// Int default +// Int8 default +// Int16 default +// Int32 default +// Int64 default +// Uint default +// Uint8 default +// Uint16 default +// Uint32 default +// Uint64 default +// Float32 default +// Float64 default +// Complex64 custom +// Complex128 custom +// Array default +// Chan custom +// Func custom +// Interface custom +// Map default (*) +// Ptr default +// Slice default +// String default +// Struct custom +// UnsafePointer custom +// +// (*) Maps are treated as value types by this package, even if they are +// pointers internally. If you want to save two independent references +// to the same map value, you must explicitly use a pointer to a map. +package state + +import ( + "fmt" + "io" + "reflect" + "runtime" + + pb "gvisor.googlesource.com/gvisor/pkg/state/object_go_proto" +) + +// ErrState is returned when an error is encountered during encode/decode. +type ErrState struct { + // Err is the underlying error. + Err error + + // path is the visit path from root to the current object. + path string + + // trace is the stack trace. + trace string +} + +// Error returns a sensible description of the state error. +func (e *ErrState) Error() string { + return fmt.Sprintf("%v:\nstate path: %s\n%s", e.Err, e.path, e.trace) +} + +// Save saves the given object state. +func Save(w io.Writer, rootPtr interface{}, stats *Stats) error { + // Create the encoding state. + es := &encodeState{ + idsByObject: make(map[uintptr]uint64), + w: w, + stats: stats, + } + + // Perform the encoding. + return es.safely(func() { + es.Serialize(reflect.ValueOf(rootPtr).Elem()) + }) +} + +// Load loads a checkpoint. +func Load(r io.Reader, rootPtr interface{}, stats *Stats) error { + // Create the decoding state. + ds := &decodeState{ + objectsByID: make(map[uint64]*objectState), + deferred: make(map[uint64]*pb.Object), + r: r, + stats: stats, + } + + // Attempt our decode. + return ds.safely(func() { + ds.Deserialize(reflect.ValueOf(rootPtr).Elem()) + }) +} + +// Fns are the state dispatch functions. +type Fns struct { + // Save is a function like Save(concreteType, Map). + Save interface{} + + // Load is a function like Load(concreteType, Map). + Load interface{} +} + +// Save executes the save function. +func (fns *Fns) invokeSave(obj reflect.Value, m Map) { + reflect.ValueOf(fns.Save).Call([]reflect.Value{obj, reflect.ValueOf(m)}) +} + +// Load executes the load function. +func (fns *Fns) invokeLoad(obj reflect.Value, m Map) { + reflect.ValueOf(fns.Load).Call([]reflect.Value{obj, reflect.ValueOf(m)}) +} + +// validateStateFn ensures types are correct. +func validateStateFn(fn interface{}, typ reflect.Type) bool { + fnTyp := reflect.TypeOf(fn) + if fnTyp.Kind() != reflect.Func { + return false + } + if fnTyp.NumIn() != 2 { + return false + } + if fnTyp.NumOut() != 0 { + return false + } + if fnTyp.In(0) != typ { + return false + } + if fnTyp.In(1) != reflect.TypeOf(Map{}) { + return false + } + return true +} + +// Validate validates all state functions. +func (fns *Fns) Validate(typ reflect.Type) bool { + return validateStateFn(fns.Save, typ) && validateStateFn(fns.Load, typ) +} + +type typeDatabase struct { + // nameToType is a forward lookup table. + nameToType map[string]reflect.Type + + // typeToName is the reverse lookup table. + typeToName map[reflect.Type]string + + // typeToFns is the function lookup table. + typeToFns map[reflect.Type]Fns +} + +// registeredTypes is a database used for SaveInterface and LoadInterface. +var registeredTypes = typeDatabase{ + nameToType: make(map[string]reflect.Type), + typeToName: make(map[reflect.Type]string), + typeToFns: make(map[reflect.Type]Fns), +} + +// register registers a type under the given name. This will generally be +// called via init() methods, and therefore uses panic to propagate errors. +func (t *typeDatabase) register(name string, typ reflect.Type, fns Fns) { + // We can't allow name collisions. + if ot, ok := t.nameToType[name]; ok { + panic(fmt.Sprintf("type %q can't use name %q, already in use by type %q", typ.Name(), name, ot.Name())) + } + + // Or multiple registrations. + if on, ok := t.typeToName[typ]; ok { + panic(fmt.Sprintf("type %q can't be registered as %q, already registered as %q", typ.Name(), name, on)) + } + + t.nameToType[name] = typ + t.typeToName[typ] = name + t.typeToFns[typ] = fns +} + +// lookupType finds a type given a name. +func (t *typeDatabase) lookupType(name string) (reflect.Type, bool) { + typ, ok := t.nameToType[name] + return typ, ok +} + +// lookupName finds a name given a type. +func (t *typeDatabase) lookupName(typ reflect.Type) (string, bool) { + name, ok := t.typeToName[typ] + return name, ok +} + +// lookupFns finds functions given a type. +func (t *typeDatabase) lookupFns(typ reflect.Type) (Fns, bool) { + fns, ok := t.typeToFns[typ] + return fns, ok +} + +// Register must be called for any interface implementation types that +// implements Loader. +// +// Register should be called either immediately after startup or via init() +// methods. Double registration of either names or types will result in a panic. +// +// No synchronization is provided; this should only be called in init. +// +// Example usage: +// +// state.Register("Foo", (*Foo)(nil), state.Fns{ +// Save: (*Foo).Save, +// Load: (*Foo).Load, +// }) +// +func Register(name string, instance interface{}, fns Fns) { + registeredTypes.register(name, reflect.TypeOf(instance), fns) +} + +// IsZeroValue checks if the given value is the zero value. +// +// This function is used by the stateify tool. +func IsZeroValue(val interface{}) bool { + if val == nil { + return true + } + return reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface()) +} + +// step captures one encoding / decoding step. On each step, there is up to one +// choice made, which is captured by non-nil param. We intentionally do not +// eagerly create the final path string, as that will only be needed upon panic. +type step struct { + // dereference indicate if the current object is obtained by + // dereferencing a pointer. + dereference bool + + // format is the formatting string that takes param below, if + // non-nil. For example, in array indexing case, we have "[%d]". + format string + + // param stores the choice made at the current encoding / decoding step. + // For eaxmple, in array indexing case, param stores the index. When no + // choice is made, e.g. dereference, param should be nil. + param interface{} +} + +// recoverable is the state encoding / decoding panic recovery facility. It is +// also used to store encoding / decoding steps as well as the reference to the +// original queued object from which the current object is dispatched. The +// complete encoding / decoding path is synthesised from the steps in all queued +// objects leading to the current object. +type recoverable struct { + from *recoverable + steps []step +} + +// push enters a new context level. +func (sr *recoverable) push(dereference bool, format string, param interface{}) { + sr.steps = append(sr.steps, step{dereference, format, param}) +} + +// pop exits the current context level. +func (sr *recoverable) pop() { + if len(sr.steps) <= 1 { + return + } + sr.steps = sr.steps[:len(sr.steps)-1] +} + +// path returns the complete encoding / decoding path from root. This is only +// called upon panic. +func (sr *recoverable) path() string { + if sr.from == nil { + return "root" + } + p := sr.from.path() + for _, s := range sr.steps { + if s.dereference { + p = fmt.Sprintf("*(%s)", p) + } + if s.param == nil { + p += s.format + } else { + p += fmt.Sprintf(s.format, s.param) + } + } + return p +} + +func (sr *recoverable) copy() recoverable { + return recoverable{from: sr.from, steps: append([]step(nil), sr.steps...)} +} + +// safely executes the given function, catching a panic and unpacking as an error. +// +// The error flow through the state package uses panic and recover. There are +// two important reasons for this: +// +// 1) Many of the reflection methods will already panic with invalid data or +// violated assumptions. We would want to recover anyways here. +// +// 2) It allows us to eliminate boilerplate within Save() and Load() functions. +// In nearly all cases, when the low-level serialization functions fail, you +// will want the checkpoint to fail anyways. Plumbing errors through every +// method doesn't add a lot of value. If there are specific error conditions +// that you'd like to handle, you should add appropriate functionality to +// objects themselves prior to calling Save() and Load(). +func (sr *recoverable) safely(fn func()) (err error) { + defer func() { + if r := recover(); r != nil { + es := new(ErrState) + if e, ok := r.(error); ok { + es.Err = e + } else { + es.Err = fmt.Errorf("%v", r) + } + + es.path = sr.path() + + // Make a stack. We don't know how big it will be ahead + // of time, but want to make sure we get the whole + // thing. So we just do a stupid brute force approach. + var stack []byte + for sz := 1024; ; sz *= 2 { + stack = make([]byte, sz) + n := runtime.Stack(stack, false) + if n < sz { + es.trace = string(stack[:n]) + break + } + } + + // Set the error. + err = es + } + }() + + // Execute the function. + fn() + return nil +} diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go new file mode 100644 index 000000000..d5a739f18 --- /dev/null +++ b/pkg/state/state_test.go @@ -0,0 +1,719 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "bytes" + "io/ioutil" + "math" + "reflect" + "testing" +) + +// TestCase is used to define a single success/failure testcase of +// serialization of a set of objects. +type TestCase struct { + // Name is the name of the test case. + Name string + + // Objects is the list of values to serialize. + Objects []interface{} + + // Fail is whether the test case is supposed to fail or not. + Fail bool +} + +// runTest runs all testcases. +func runTest(t *testing.T, tests []TestCase) { + for _, test := range tests { + t.Logf("TEST %s:", test.Name) + for i, root := range test.Objects { + t.Logf(" case#%d: %#v", i, root) + + // Save the passed object. + saveBuffer := &bytes.Buffer{} + saveObjectPtr := reflect.New(reflect.TypeOf(root)) + saveObjectPtr.Elem().Set(reflect.ValueOf(root)) + if err := Save(saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { + t.Errorf(" FAIL: Save failed unexpectedly: %v", err) + continue + } else if err != nil { + t.Logf(" PASS: Save failed as expected: %v", err) + continue + } + + // Load a new copy of the object. + loadObjectPtr := reflect.New(reflect.TypeOf(root)) + if err := Load(bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { + t.Errorf(" FAIL: Load failed unexpectedly: %v", err) + continue + } else if err != nil { + t.Logf(" PASS: Load failed as expected: %v", err) + continue + } + + // Compare the values. + loadedValue := loadObjectPtr.Elem().Interface() + if eq := reflect.DeepEqual(root, loadedValue); !eq && !test.Fail { + t.Errorf(" FAIL: Objects differs; got %#v", loadedValue) + continue + } else if !eq { + t.Logf(" PASS: Object different as expected.") + continue + } + + // Everything went okay. Is that good? + if test.Fail { + t.Errorf(" FAIL: Unexpected success.") + } else { + t.Logf(" PASS: Success.") + } + } + } +} + +// dumbStruct is a struct which does not implement the loader/saver interface. +// We expect that serialization of this struct will fail. +type dumbStruct struct { + A int + B int +} + +// smartStruct is a struct which does implement the loader/saver interface. +// We expect that serialization of this struct will succeed. +type smartStruct struct { + A int + B int +} + +func (s *smartStruct) save(m Map) { + m.Save("A", &s.A) + m.Save("B", &s.B) +} + +func (s *smartStruct) load(m Map) { + m.Load("A", &s.A) + m.Load("B", &s.B) +} + +// valueLoadStruct uses a value load. +type valueLoadStruct struct { + v int +} + +func (v *valueLoadStruct) save(m Map) { + m.SaveValue("v", v.v) +} + +func (v *valueLoadStruct) load(m Map) { + m.LoadValue("v", new(int), func(value interface{}) { + v.v = value.(int) + }) +} + +// afterLoadStruct has an AfterLoad function. +type afterLoadStruct struct { + v int +} + +func (a *afterLoadStruct) save(m Map) { +} + +func (a *afterLoadStruct) load(m Map) { + m.AfterLoad(func() { + a.v++ + }) +} + +// genericContainer is a generic dispatcher. +type genericContainer struct { + v interface{} +} + +func (g *genericContainer) save(m Map) { + m.Save("v", &g.v) +} + +func (g *genericContainer) load(m Map) { + m.Load("v", &g.v) +} + +// sliceContainer is a generic slice. +type sliceContainer struct { + v []interface{} +} + +func (s *sliceContainer) save(m Map) { + m.Save("v", &s.v) +} + +func (s *sliceContainer) load(m Map) { + m.Load("v", &s.v) +} + +// mapContainer is a generic map. +type mapContainer struct { + v map[int]interface{} +} + +func (mc *mapContainer) save(m Map) { + m.Save("v", &mc.v) +} + +func (mc *mapContainer) load(m Map) { + // Some of the test cases below assume legacy behavior wherein maps + // will automatically inherit dependencies. + m.LoadWait("v", &mc.v) +} + +// dumbMap is a map which does not implement the loader/saver interface. +// Serialization of this map will default to the standard encode/decode logic. +type dumbMap map[string]int + +// pointerStruct contains various pointers, shared and non-shared, and pointers +// to pointers. We expect that serialization will respect the structure. +type pointerStruct struct { + A *int + B *int + C *int + D *int + + AA **int + BB **int +} + +func (p *pointerStruct) save(m Map) { + m.Save("A", &p.A) + m.Save("B", &p.B) + m.Save("C", &p.C) + m.Save("D", &p.D) + m.Save("AA", &p.AA) + m.Save("BB", &p.BB) +} + +func (p *pointerStruct) load(m Map) { + m.Load("A", &p.A) + m.Load("B", &p.B) + m.Load("C", &p.C) + m.Load("D", &p.D) + m.Load("AA", &p.AA) + m.Load("BB", &p.BB) +} + +// testInterface is a trivial interface example. +type testInterface interface { + Foo() +} + +// testImpl is a trivial implementation of testInterface. +type testImpl struct { +} + +// Foo satisfies testInterface. +func (t *testImpl) Foo() { +} + +// testImpl is trivially serializable. +func (t *testImpl) save(m Map) { +} + +// testImpl is trivially serializable. +func (t *testImpl) load(m Map) { +} + +// testI demonstrates interface dispatching. +type testI struct { + I testInterface +} + +func (t *testI) save(m Map) { + m.Save("I", &t.I) +} + +func (t *testI) load(m Map) { + m.Load("I", &t.I) +} + +// cycleStruct is used to implement basic cycles. +type cycleStruct struct { + c *cycleStruct +} + +func (c *cycleStruct) save(m Map) { + m.Save("c", &c.c) +} + +func (c *cycleStruct) load(m Map) { + m.Load("c", &c.c) +} + +// badCycleStruct actually has deadlocking dependencies. +// +// This should pass if b.b = {nil|b} and fail otherwise. +type badCycleStruct struct { + b *badCycleStruct +} + +func (b *badCycleStruct) save(m Map) { + m.Save("b", &b.b) +} + +func (b *badCycleStruct) load(m Map) { + m.LoadWait("b", &b.b) + m.AfterLoad(func() { + // This is not executable, since AfterLoad requires that the + // object and all dependencies are complete. This should cause + // a deadlock error during load. + }) +} + +// emptyStructPointer points to an empty struct. +type emptyStructPointer struct { + nothing *struct{} +} + +func (e *emptyStructPointer) save(m Map) { + m.Save("nothing", &e.nothing) +} + +func (e *emptyStructPointer) load(m Map) { + m.Load("nothing", &e.nothing) +} + +// truncateInteger truncates an integer. +type truncateInteger struct { + v int64 + v2 int32 +} + +func (t *truncateInteger) save(m Map) { + t.v2 = int32(t.v) + m.Save("v", &t.v) +} + +func (t *truncateInteger) load(m Map) { + m.Load("v", &t.v2) + t.v = int64(t.v2) +} + +// truncateUnsignedInteger truncates an unsigned integer. +type truncateUnsignedInteger struct { + v uint64 + v2 uint32 +} + +func (t *truncateUnsignedInteger) save(m Map) { + t.v2 = uint32(t.v) + m.Save("v", &t.v) +} + +func (t *truncateUnsignedInteger) load(m Map) { + m.Load("v", &t.v2) + t.v = uint64(t.v2) +} + +// truncateFloat truncates a floating point number. +type truncateFloat struct { + v float64 + v2 float32 +} + +func (t *truncateFloat) save(m Map) { + t.v2 = float32(t.v) + m.Save("v", &t.v) +} + +func (t *truncateFloat) load(m Map) { + m.Load("v", &t.v2) + t.v = float64(t.v2) +} + +func TestTypes(t *testing.T) { + // x and y are basic integers, while xp points to x. + x := 1 + y := 2 + xp := &x + + // cs is a single object cycle. + cs := cycleStruct{nil} + cs.c = &cs + + // cs1 and cs2 are in a two object cycle. + cs1 := cycleStruct{nil} + cs2 := cycleStruct{nil} + cs1.c = &cs2 + cs2.c = &cs1 + + // bs is a single object cycle. + bs := badCycleStruct{nil} + bs.b = &bs + + // bs2 and bs2 are in a deadlocking cycle. + bs1 := badCycleStruct{nil} + bs2 := badCycleStruct{nil} + bs1.b = &bs2 + bs2.b = &bs1 + + // regular nils. + var ( + nilmap dumbMap + nilslice []byte + ) + + // embed points to embedded fields. + embed1 := pointerStruct{} + embed1.AA = &embed1.A + embed2 := pointerStruct{} + embed2.BB = &embed2.B + + // es1 contains two structs pointing to the same empty struct. + es := emptyStructPointer{new(struct{})} + es1 := []emptyStructPointer{es, es} + + tests := []TestCase{ + { + Name: "bool", + Objects: []interface{}{ + true, + false, + }, + }, + { + Name: "integers", + Objects: []interface{}{ + int(0), + int(1), + int(-1), + int8(0), + int8(1), + int8(-1), + int16(0), + int16(1), + int16(-1), + int32(0), + int32(1), + int32(-1), + int64(0), + int64(1), + int64(-1), + }, + }, + { + Name: "unsigned integers", + Objects: []interface{}{ + uint(0), + uint(1), + uint8(0), + uint8(1), + uint16(0), + uint16(1), + uint32(1), + uint64(0), + uint64(1), + }, + }, + { + Name: "strings", + Objects: []interface{}{ + "", + "foo", + "bar", + }, + }, + { + Name: "slices", + Objects: []interface{}{ + []int{-1, 0, 1}, + []*int{&x, &x, &x}, + []int{1, 2, 3}[0:1], + []int{1, 2, 3}[1:2], + make([]byte, 32), + make([]byte, 32)[:16], + make([]byte, 32)[:16:20], + nilslice, + }, + }, + { + Name: "arrays", + Objects: []interface{}{ + &[1048576]bool{false, true, false, true}, + &[1048576]uint8{0, 1, 2, 3}, + &[1048576]byte{0, 1, 2, 3}, + &[1048576]uint16{0, 1, 2, 3}, + &[1048576]uint{0, 1, 2, 3}, + &[1048576]uint32{0, 1, 2, 3}, + &[1048576]uint64{0, 1, 2, 3}, + &[1048576]uintptr{0, 1, 2, 3}, + &[1048576]int8{0, -1, -2, -3}, + &[1048576]int16{0, -1, -2, -3}, + &[1048576]int32{0, -1, -2, -3}, + &[1048576]int64{0, -1, -2, -3}, + &[1048576]float32{0, 1.1, 2.2, 3.3}, + &[1048576]float64{0, 1.1, 2.2, 3.3}, + }, + }, + { + Name: "pointers", + Objects: []interface{}{ + &pointerStruct{A: &x, B: &x, C: &y, D: &y, AA: &xp, BB: &xp}, + &pointerStruct{}, + }, + }, + { + Name: "empty struct", + Objects: []interface{}{ + struct{}{}, + }, + }, + { + Name: "unenlightened structs", + Objects: []interface{}{ + &dumbStruct{A: 1, B: 2}, + }, + Fail: true, + }, + { + Name: "enlightened structs", + Objects: []interface{}{ + &smartStruct{A: 1, B: 2}, + }, + }, + { + Name: "load-hooks", + Objects: []interface{}{ + &afterLoadStruct{v: 1}, + &valueLoadStruct{v: 1}, + &genericContainer{v: &afterLoadStruct{v: 1}}, + &genericContainer{v: &valueLoadStruct{v: 1}}, + &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}}, + &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}}, + &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}}, + &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}}, + }, + }, + { + Name: "maps", + Objects: []interface{}{ + dumbMap{"a": -1, "b": 0, "c": 1}, + map[smartStruct]int{{}: 0, {A: 1}: 1}, + nilmap, + &mapContainer{v: map[int]interface{}{0: &smartStruct{A: 1}}}, + }, + }, + { + Name: "interfaces", + Objects: []interface{}{ + &testI{&testImpl{}}, + &testI{nil}, + &testI{(*testImpl)(nil)}, + }, + }, + { + Name: "unregistered-interfaces", + Objects: []interface{}{ + &genericContainer{v: afterLoadStruct{v: 1}}, + &genericContainer{v: valueLoadStruct{v: 1}}, + &sliceContainer{v: []interface{}{afterLoadStruct{v: 1}}}, + &sliceContainer{v: []interface{}{valueLoadStruct{v: 1}}}, + &mapContainer{v: map[int]interface{}{0: afterLoadStruct{v: 1}}}, + &mapContainer{v: map[int]interface{}{0: valueLoadStruct{v: 1}}}, + }, + Fail: true, + }, + { + Name: "cycles", + Objects: []interface{}{ + &cs, + &cs1, + &cycleStruct{&cs1}, + &cycleStruct{&cs}, + &badCycleStruct{nil}, + &bs, + }, + }, + { + Name: "deadlock", + Objects: []interface{}{ + &bs1, + }, + Fail: true, + }, + { + Name: "embed", + Objects: []interface{}{ + &embed1, + &embed2, + }, + Fail: true, + }, + { + Name: "empty structs", + Objects: []interface{}{ + new(struct{}), + es, + es1, + }, + }, + { + Name: "truncated okay", + Objects: []interface{}{ + &truncateInteger{v: 1}, + &truncateUnsignedInteger{v: 1}, + &truncateFloat{v: 1.0}, + }, + }, + { + Name: "truncated bad", + Objects: []interface{}{ + &truncateInteger{v: math.MaxInt32 + 1}, + &truncateUnsignedInteger{v: math.MaxUint32 + 1}, + &truncateFloat{v: math.MaxFloat32 * 2}, + }, + Fail: true, + }, + } + + runTest(t, tests) +} + +// benchStruct is used for benchmarking. +type benchStruct struct { + b *benchStruct + + // Dummy data is included to ensure that these objects are large. + // This is to detect possible regression when registering objects. + _ [4096]byte +} + +func (b *benchStruct) save(m Map) { + m.Save("b", &b.b) +} + +func (b *benchStruct) load(m Map) { + m.LoadWait("b", &b.b) + m.AfterLoad(b.afterLoad) +} + +func (b *benchStruct) afterLoad() { + // Do nothing, just force scheduling. +} + +// buildObject builds a benchmark object. +func buildObject(n int) (b *benchStruct) { + for i := 0; i < n; i++ { + b = &benchStruct{b: b} + } + return +} + +func BenchmarkEncoding(b *testing.B) { + b.StopTimer() + bs := buildObject(b.N) + var stats Stats + b.StartTimer() + if err := Save(ioutil.Discard, bs, &stats); err != nil { + b.Errorf("save failed: %v", err) + } + b.StopTimer() + if b.N > 1000 { + b.Logf("breakdown (n=%d): %s", b.N, &stats) + } +} + +func BenchmarkDecoding(b *testing.B) { + b.StopTimer() + bs := buildObject(b.N) + var newBS benchStruct + buf := &bytes.Buffer{} + if err := Save(buf, bs, nil); err != nil { + b.Errorf("save failed: %v", err) + } + var stats Stats + b.StartTimer() + if err := Load(buf, &newBS, &stats); err != nil { + b.Errorf("load failed: %v", err) + } + b.StopTimer() + if b.N > 1000 { + b.Logf("breakdown (n=%d): %s", b.N, &stats) + } +} + +func init() { + Register("stateTest.smartStruct", (*smartStruct)(nil), Fns{ + Save: (*smartStruct).save, + Load: (*smartStruct).load, + }) + Register("stateTest.afterLoadStruct", (*afterLoadStruct)(nil), Fns{ + Save: (*afterLoadStruct).save, + Load: (*afterLoadStruct).load, + }) + Register("stateTest.valueLoadStruct", (*valueLoadStruct)(nil), Fns{ + Save: (*valueLoadStruct).save, + Load: (*valueLoadStruct).load, + }) + Register("stateTest.genericContainer", (*genericContainer)(nil), Fns{ + Save: (*genericContainer).save, + Load: (*genericContainer).load, + }) + Register("stateTest.sliceContainer", (*sliceContainer)(nil), Fns{ + Save: (*sliceContainer).save, + Load: (*sliceContainer).load, + }) + Register("stateTest.mapContainer", (*mapContainer)(nil), Fns{ + Save: (*mapContainer).save, + Load: (*mapContainer).load, + }) + Register("stateTest.pointerStruct", (*pointerStruct)(nil), Fns{ + Save: (*pointerStruct).save, + Load: (*pointerStruct).load, + }) + Register("stateTest.testImpl", (*testImpl)(nil), Fns{ + Save: (*testImpl).save, + Load: (*testImpl).load, + }) + Register("stateTest.testI", (*testI)(nil), Fns{ + Save: (*testI).save, + Load: (*testI).load, + }) + Register("stateTest.cycleStruct", (*cycleStruct)(nil), Fns{ + Save: (*cycleStruct).save, + Load: (*cycleStruct).load, + }) + Register("stateTest.badCycleStruct", (*badCycleStruct)(nil), Fns{ + Save: (*badCycleStruct).save, + Load: (*badCycleStruct).load, + }) + Register("stateTest.emptyStructPointer", (*emptyStructPointer)(nil), Fns{ + Save: (*emptyStructPointer).save, + Load: (*emptyStructPointer).load, + }) + Register("stateTest.truncateInteger", (*truncateInteger)(nil), Fns{ + Save: (*truncateInteger).save, + Load: (*truncateInteger).load, + }) + Register("stateTest.truncateUnsignedInteger", (*truncateUnsignedInteger)(nil), Fns{ + Save: (*truncateUnsignedInteger).save, + Load: (*truncateUnsignedInteger).load, + }) + Register("stateTest.truncateFloat", (*truncateFloat)(nil), Fns{ + Save: (*truncateFloat).save, + Load: (*truncateFloat).load, + }) + Register("stateTest.benchStruct", (*benchStruct)(nil), Fns{ + Save: (*benchStruct).save, + Load: (*benchStruct).load, + }) +} diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD new file mode 100644 index 000000000..df2c6a578 --- /dev/null +++ b/pkg/state/statefile/BUILD @@ -0,0 +1,23 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "statefile", + srcs = ["statefile.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/state/statefile", + visibility = ["//:sandbox"], + deps = [ + "//pkg/binary", + "//pkg/compressio", + "//pkg/hashio", + ], +) + +go_test( + name = "statefile_test", + size = "small", + srcs = ["statefile_test.go"], + embed = [":statefile"], + deps = ["//pkg/hashio"], +) diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go new file mode 100644 index 000000000..b25b743b7 --- /dev/null +++ b/pkg/state/statefile/statefile.go @@ -0,0 +1,233 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package statefile defines the state file data stream. +// +// This package currently does not include any details regarding the state +// encoding itself, only details regarding state metadata and data layout. +// +// The file format is defined as follows. +// +// /------------------------------------------------------\ +// | header (8-bytes) | +// +------------------------------------------------------+ +// | metadata length (8-bytes) | +// +------------------------------------------------------+ +// | metadata | +// +------------------------------------------------------+ +// | data | +// \------------------------------------------------------/ +// +// First, it includes a 8-byte magic header which is the following +// sequence of bytes [0x67, 0x56, 0x69, 0x73, 0x6f, 0x72, 0x53, 0x46] +// +// This header is followed by an 8-byte length N (big endian), and an +// ASCII-encoded JSON map that is exactly N bytes long. +// +// This map includes only strings for keys and strings for values. Keys in the +// map that begin with "_" are for internal use only. They may be read, but may +// not be provided by the user. In the future, this metadata may contain some +// information relating to the state encoding itself. +// +// After the map, the remainder of the file is the state data. +package statefile + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/json" + "fmt" + "hash" + "io" + "strings" + "time" + + "gvisor.googlesource.com/gvisor/pkg/binary" + "gvisor.googlesource.com/gvisor/pkg/compressio" + "gvisor.googlesource.com/gvisor/pkg/hashio" +) + +// keySize is the AES-256 key length. +const keySize = 32 + +// compressionChunkSize is the chunk size for compression. +const compressionChunkSize = 1024 * 1024 + +// maxMetadataSize is the size limit of metadata section. +const maxMetadataSize = 16 * 1024 * 1024 + +// magicHeader is the byte sequence beginning each file. +var magicHeader = []byte("\x67\x56\x69\x73\x6f\x72\x53\x46") + +// ErrBadMagic is returned if the header does not match. +var ErrBadMagic = fmt.Errorf("bad magic header") + +// ErrMetadataMissing is returned if the state file is missing mandatory metadata. +var ErrMetadataMissing = fmt.Errorf("missing metadata") + +// ErrInvalidMetadataLength is returned if the metadata length is too large. +var ErrInvalidMetadataLength = fmt.Errorf("metadata length invalid, maximum size is %d", maxMetadataSize) + +// ErrMetadataInvalid is returned if passed metadata is invalid. +var ErrMetadataInvalid = fmt.Errorf("metadata invalid, can't start with _") + +// NewWriter returns a state data writer for a statefile. +// +// Note that the returned WriteCloser must be closed. +func NewWriter(w io.Writer, key []byte, metadata map[string]string, compressionLevel int) (io.WriteCloser, error) { + if metadata == nil { + metadata = make(map[string]string) + } + for k := range metadata { + if strings.HasPrefix(k, "_") { + return nil, ErrMetadataInvalid + } + } + + // Create our HMAC function. + h := hmac.New(sha256.New, key) + mw := io.MultiWriter(w, h) + + // First, write the header. + if _, err := mw.Write(magicHeader); err != nil { + return nil, err + } + + // Generate a timestamp, for convenience only. + metadata["_timestamp"] = time.Now().UTC().String() + defer delete(metadata, "_timestamp") + + // Write the metadata. + b, err := json.Marshal(metadata) + if err != nil { + return nil, err + } + + if len(b) > maxMetadataSize { + return nil, ErrInvalidMetadataLength + } + + // Metadata length. + if err := binary.WriteUint64(mw, binary.BigEndian, uint64(len(b))); err != nil { + return nil, err + } + // Metadata bytes; io.MultiWriter will return a short write error if + // any of the writers returns < n. + if _, err := mw.Write(b); err != nil { + return nil, err + } + // Write the current hash. + cur := h.Sum(nil) + for done := 0; done < len(cur); { + n, err := mw.Write(cur[done:]) + done += n + if err != nil { + return nil, err + } + } + + w = hashio.NewWriter(w, h) + + // Wrap in compression. + return compressio.NewWriter(w, compressionChunkSize, compressionLevel) +} + +// MetadataUnsafe reads out the metadata from a state file without verifying any +// HMAC. This function shouldn't be called for untrusted input files. +func MetadataUnsafe(r io.Reader) (map[string]string, error) { + return metadata(r, nil) +} + +// metadata validates the magic header and reads out the metadata from a state +// data stream. +func metadata(r io.Reader, h hash.Hash) (map[string]string, error) { + if h != nil { + r = io.TeeReader(r, h) + } + + // Read and validate magic header. + b := make([]byte, len(magicHeader)) + if _, err := r.Read(b); err != nil { + return nil, err + } + if !bytes.Equal(b, magicHeader) { + return nil, ErrBadMagic + } + + // Read and validate metadata. + b, err := func() (b []byte, err error) { + defer func() { + if r := recover(); r != nil { + b = nil + err = fmt.Errorf("%v", r) + } + }() + + metadataLen, err := binary.ReadUint64(r, binary.BigEndian) + if err != nil { + return nil, err + } + if metadataLen > maxMetadataSize { + return nil, ErrInvalidMetadataLength + } + b = make([]byte, int(metadataLen)) + if _, err := io.ReadFull(r, b); err != nil { + return nil, err + } + return b, nil + }() + if err != nil { + return nil, err + } + + if h != nil { + // Check the hash prior to decoding. + cur := h.Sum(nil) + buf := make([]byte, len(cur)) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + if !hmac.Equal(cur, buf) { + return nil, hashio.ErrHashMismatch + } + } + + // Decode the metadata. + metadata := make(map[string]string) + if err := json.Unmarshal(b, &metadata); err != nil { + return nil, err + } + + return metadata, nil +} + +// NewReader returns a reader for a statefile. +func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) { + // Read the metadata with the hash. + h := hmac.New(sha256.New, key) + metadata, err := metadata(r, h) + if err != nil { + return nil, nil, err + } + + r = hashio.NewReader(r, h) + + // Wrap in compression. + rc, err := compressio.NewReader(r) + if err != nil { + return nil, nil, err + } + return rc, metadata, nil +} diff --git a/pkg/state/statefile/statefile_test.go b/pkg/state/statefile/statefile_test.go new file mode 100644 index 000000000..6e67b51de --- /dev/null +++ b/pkg/state/statefile/statefile_test.go @@ -0,0 +1,299 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package statefile + +import ( + "bytes" + "compress/flate" + crand "crypto/rand" + "encoding/base64" + "io" + "math/rand" + "testing" + + "gvisor.googlesource.com/gvisor/pkg/hashio" +) + +func randomKey() ([]byte, error) { + r := make([]byte, base64.RawStdEncoding.DecodedLen(keySize)) + if _, err := io.ReadFull(crand.Reader, r); err != nil { + return nil, err + } + key := make([]byte, keySize) + base64.RawStdEncoding.Encode(key, r) + return key, nil +} + +type testCase struct { + name string + data []byte + metadata map[string]string +} + +func TestStatefile(t *testing.T) { + cases := []testCase{ + // Various data sizes. + {"nil", nil, nil}, + {"empty", []byte(""), nil}, + {"some", []byte("_"), nil}, + {"one", []byte("0"), nil}, + {"two", []byte("01"), nil}, + {"three", []byte("012"), nil}, + {"four", []byte("0123"), nil}, + {"five", []byte("01234"), nil}, + {"six", []byte("012356"), nil}, + {"seven", []byte("0123567"), nil}, + {"eight", []byte("01235678"), nil}, + + // Make sure we have one longer than the hash length. + {"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil}, + + // Make sure we have one longer than the segment size. + {"segments", make([]byte, 3*hashio.SegmentSize), nil}, + {"segments minus one", make([]byte, 3*hashio.SegmentSize-1), nil}, + {"segments plus one", make([]byte, 3*hashio.SegmentSize+1), nil}, + {"segments minus hash", make([]byte, 3*hashio.SegmentSize-32), nil}, + {"segments plus hash", make([]byte, 3*hashio.SegmentSize+32), nil}, + {"large", make([]byte, 30*hashio.SegmentSize), nil}, + + // Different metadata. + {"one metadata", []byte("data"), map[string]string{"foo": "bar"}}, + {"two metadata", []byte("data"), map[string]string{"foo": "bar", "one": "two"}}, + } + + for _, c := range cases { + // Generate a key. + integrityKey, err := randomKey() + if err != nil { + t.Errorf("can't generate key: got %v, excepted nil", err) + continue + } + + t.Run(c.name, func(t *testing.T) { + for _, key := range [][]byte{nil, integrityKey} { + t.Run("key="+string(key), func(t *testing.T) { + // Encoding happens via a buffer. + var bufEncoded bytes.Buffer + var bufDecoded bytes.Buffer + + // Do all the writing. + w, err := NewWriter(&bufEncoded, key, c.metadata, flate.BestSpeed) + if err != nil { + t.Fatalf("error creating writer: got %v, expected nil", err) + } + if _, err := io.Copy(w, bytes.NewBuffer(c.data)); err != nil { + t.Fatalf("error during write: got %v, expected nil", err) + } + + // Finish the sum. + if err := w.Close(); err != nil { + t.Fatalf("error during close: got %v, expected nil", err) + } + + t.Logf("original data: %d bytes, encoded: %d bytes.", + len(c.data), len(bufEncoded.Bytes())) + + // Do all the reading. + r, metadata, err := NewReader(bytes.NewReader(bufEncoded.Bytes()), key) + if err != nil { + t.Fatalf("error creating reader: got %v, expected nil", err) + } + if _, err := io.Copy(&bufDecoded, r); err != nil { + t.Fatalf("error during read: got %v, expected nil", err) + } + + // Check that the data matches. + if !bytes.Equal(c.data, bufDecoded.Bytes()) { + t.Fatalf("data didn't match (%d vs %d bytes)", len(bufDecoded.Bytes()), len(c.data)) + } + + // Check that the metadata matches. + for k, v := range c.metadata { + nv, ok := metadata[k] + if !ok { + t.Fatalf("missing metadata: %s", k) + } + if v != nv { + t.Fatalf("mismatched metdata for %s: got %s, expected %s", k, nv, v) + } + } + + // Change the data and verify that it fails. + b := append([]byte(nil), bufEncoded.Bytes()...) + b[rand.Intn(len(b))]++ + r, _, err = NewReader(bytes.NewReader(b), key) + if err == nil { + _, err = io.Copy(&bufDecoded, r) + } + if err == nil { + t.Error("got no error: expected error on data corruption") + } + + // Change the key and verify that it fails. + if key == nil { + key = integrityKey + } else { + key[rand.Intn(len(key))]++ + } + r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), key) + if err == nil { + _, err = io.Copy(&bufDecoded, r) + } + if err != hashio.ErrHashMismatch { + t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err) + } + }) + } + }) + } +} + +const benchmarkDataSize = 10 * 1024 * 1024 + +func benchmark(b *testing.B, size int, write bool, compressible bool) { + b.StopTimer() + b.SetBytes(benchmarkDataSize) + + // Generate source data. + var source []byte + if compressible { + // For compressible data, we use essentially all zeros. + source = make([]byte, benchmarkDataSize) + } else { + // For non-compressible data, we use random base64 data (to + // make it marginally compressible, a ratio of 75%). + var sourceBuf bytes.Buffer + bufW := base64.NewEncoder(base64.RawStdEncoding, &sourceBuf) + bufR := rand.New(rand.NewSource(0)) + if _, err := io.CopyN(bufW, bufR, benchmarkDataSize); err != nil { + b.Fatalf("unable to seed random data: %v", err) + } + source = sourceBuf.Bytes() + } + + // Generate a random key for integrity check. + key, err := randomKey() + if err != nil { + b.Fatalf("error generating key: %v", err) + } + + // Define our benchmark functions. Prior to running the readState + // function here, you must execute the writeState function at least + // once (done below). + var stateBuf bytes.Buffer + writeState := func() { + stateBuf.Reset() + w, err := NewWriter(&stateBuf, key, nil, flate.BestSpeed) + if err != nil { + b.Fatalf("error creating writer: %v", err) + } + for done := 0; done < len(source); { + chunk := size // limit size. + if done+chunk > len(source) { + chunk = len(source) - done + } + n, err := w.Write(source[done : done+chunk]) + done += n + if n == 0 && err != nil { + b.Fatalf("error during write: %v", err) + } + } + if err := w.Close(); err != nil { + b.Fatalf("error closing writer: %v", err) + } + } + readState := func() { + tmpBuf := bytes.NewBuffer(stateBuf.Bytes()) + r, _, err := NewReader(tmpBuf, key) + if err != nil { + b.Fatalf("error creating reader: %v", err) + } + for done := 0; done < len(source); { + chunk := size // limit size. + if done+chunk > len(source) { + chunk = len(source) - done + } + n, err := r.Read(source[done : done+chunk]) + done += n + if n == 0 && err != nil { + b.Fatalf("error during read: %v", err) + } + } + } + // Generate the state once without timing to ensure that buffers have + // been appropriately allocated. + writeState() + if write { + b.StartTimer() + for i := 0; i < b.N; i++ { + writeState() + } + b.StopTimer() + } else { + b.StartTimer() + for i := 0; i < b.N; i++ { + readState() + } + b.StopTimer() + } +} + +func BenchmarkWrite1BCompressible(b *testing.B) { + benchmark(b, 1, true, true) +} + +func BenchmarkWrite1BNoncompressible(b *testing.B) { + benchmark(b, 1, true, false) +} + +func BenchmarkWrite4KCompressible(b *testing.B) { + benchmark(b, 4096, true, true) +} + +func BenchmarkWrite4KNoncompressible(b *testing.B) { + benchmark(b, 4096, true, false) +} + +func BenchmarkWrite1MCompressible(b *testing.B) { + benchmark(b, 1024*1024, true, true) +} + +func BenchmarkWrite1MNoncompressible(b *testing.B) { + benchmark(b, 1024*1024, true, false) +} + +func BenchmarkRead1BCompressible(b *testing.B) { + benchmark(b, 1, false, true) +} + +func BenchmarkRead1BNoncompressible(b *testing.B) { + benchmark(b, 1, false, false) +} + +func BenchmarkRead4KCompressible(b *testing.B) { + benchmark(b, 4096, false, true) +} + +func BenchmarkRead4KNoncompressible(b *testing.B) { + benchmark(b, 4096, false, false) +} + +func BenchmarkRead1MCompressible(b *testing.B) { + benchmark(b, 1024*1024, false, true) +} + +func BenchmarkRead1MNoncompressible(b *testing.B) { + benchmark(b, 1024*1024, false, false) +} diff --git a/pkg/state/stats.go b/pkg/state/stats.go new file mode 100644 index 000000000..1ebd8ebb4 --- /dev/null +++ b/pkg/state/stats.go @@ -0,0 +1,133 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "bytes" + "fmt" + "reflect" + "sort" + "time" +) + +type statEntry struct { + count uint + total time.Duration +} + +// Stats tracks encode / decode timing. +// +// This currently provides a meaningful String function and no other way to +// extract stats about individual types. +// +// All exported receivers accept nil. +type Stats struct { + // byType contains a breakdown of time spent by type. + byType map[reflect.Type]*statEntry + + // stack contains objects in progress. + stack []reflect.Type + + // last is the last start time. + last time.Time +} + +// sample adds the given number of samples to the given object. +func (s *Stats) sample(typ reflect.Type, count uint) { + if s.byType == nil { + s.byType = make(map[reflect.Type]*statEntry) + } + entry, ok := s.byType[typ] + if !ok { + entry = new(statEntry) + s.byType[typ] = entry + } + now := time.Now() + entry.count += count + entry.total += now.Sub(s.last) + s.last = now +} + +// Start starts a sample. +func (s *Stats) Start(obj reflect.Value) { + if s == nil { + return + } + if len(s.stack) > 0 { + last := s.stack[len(s.stack)-1] + s.sample(last, 0) + } else { + // First time sample. + s.last = time.Now() + } + s.stack = append(s.stack, obj.Type()) +} + +// Done finishes the current sample. +func (s *Stats) Done() { + if s == nil { + return + } + last := s.stack[len(s.stack)-1] + s.sample(last, 1) + s.stack = s.stack[:len(s.stack)-1] +} + +type sliceEntry struct { + typ reflect.Type + entry *statEntry +} + +// String returns a table representation of the stats. +func (s *Stats) String() string { + if s == nil || len(s.byType) == 0 { + return "(no data)" + } + + // Build a list of stat entries. + ss := make([]sliceEntry, 0, len(s.byType)) + for typ, entry := range s.byType { + ss = append(ss, sliceEntry{ + typ: typ, + entry: entry, + }) + } + + // Sort by total time (descending). + sort.Slice(ss, func(i, j int) bool { + return ss[i].entry.total > ss[j].entry.total + }) + + // Print the stat results. + var ( + buf bytes.Buffer + count uint + total time.Duration + ) + buf.WriteString("\n") + buf.WriteString(fmt.Sprintf("%12s | %8s | %8s | %s\n", "total", "count", "per", "type")) + buf.WriteString("-------------+----------+----------+-------------\n") + for _, se := range ss { + count += se.entry.count + total += se.entry.total + per := se.entry.total / time.Duration(se.entry.count) + buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | %s\n", + se.entry.total, se.entry.count, per, se.typ.String())) + } + buf.WriteString("-------------+----------+----------+-------------\n") + buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | [all]", + total, count, total/time.Duration(count))) + return string(buf.Bytes()) +} |