summaryrefslogtreecommitdiffhomepage
path: root/pkg/state/state.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/state/state.go')
-rw-r--r--pkg/state/state.go360
1 files changed, 161 insertions, 199 deletions
diff --git a/pkg/state/state.go b/pkg/state/state.go
index 03ae2dbb0..acb629969 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -31,210 +31,226 @@
// Uint64 default
// Float32 default
// Float64 default
-// Complex64 custom
-// Complex128 custom
+// Complex64 default
+// Complex128 default
// Array default
// Chan custom
// Func custom
-// Interface custom
-// Map default (*)
+// Interface default
+// Map default
// Ptr default
// Slice default
// String default
-// Struct custom
+// Struct custom (*) Unless zero-sized.
// UnsafePointer custom
//
-// (*) Maps are treated as value types by this package, even if they are
-// pointers internally. If you want to save two independent references
-// to the same map value, you must explicitly use a pointer to a map.
+// See README.md for an overview of how encoding and decoding works.
package state
import (
"context"
"fmt"
- "io"
"reflect"
"runtime"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
+// objectID is a unique identifier assigned to each object to be serialized.
+// Each instance of an object is considered separately, i.e. if there are two
+// objects of the same type in the object graph being serialized, they'll be
+// assigned unique objectIDs.
+type objectID uint32
+
+// typeID is the identifier for a type. Types are serialized and tracked
+// alongside objects in order to avoid the overhead of encoding field names in
+// all objects.
+type typeID uint32
+
// ErrState is returned when an error is encountered during encode/decode.
type ErrState struct {
// err is the underlying error.
err error
- // path is the visit path from root to the current object.
- path string
-
// trace is the stack trace.
trace string
}
// Error returns a sensible description of the state error.
func (e *ErrState) Error() string {
- return fmt.Sprintf("%v:\nstate path: %s\n%s", e.err, e.path, e.trace)
+ return fmt.Sprintf("%v:\n%s", e.err, e.trace)
}
-// UnwrapErrState returns the underlying error in ErrState.
-//
-// If err is not *ErrState, err is returned directly.
-func UnwrapErrState(err error) error {
- if e, ok := err.(*ErrState); ok {
- return e.err
- }
- return err
+// Unwrap implements standard unwrapping.
+func (e *ErrState) Unwrap() error {
+ return e.err
}
// Save saves the given object state.
-func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error {
+func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) {
// Create the encoding state.
- es := &encodeState{
- ctx: ctx,
- idsByObject: make(map[uintptr]uint64),
- w: w,
- stats: stats,
+ es := encodeState{
+ ctx: ctx,
+ w: w,
+ types: makeTypeEncodeDatabase(),
+ zeroValues: make(map[reflect.Type]*objectEncodeState),
}
// Perform the encoding.
- return es.safely(func() {
- es.Serialize(reflect.ValueOf(rootPtr).Elem())
+ err := safely(func() {
+ es.Save(reflect.ValueOf(rootPtr).Elem())
})
+ return es.stats, err
}
// Load loads a checkpoint.
-func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error {
+func Load(ctx context.Context, r wire.Reader, rootPtr interface{}) (Stats, error) {
// Create the decoding state.
- ds := &decodeState{
- ctx: ctx,
- objectsByID: make(map[uint64]*objectState),
- deferred: make(map[uint64]*pb.Object),
- r: r,
- stats: stats,
+ ds := decodeState{
+ ctx: ctx,
+ r: r,
+ types: makeTypeDecodeDatabase(),
+ deferred: make(map[objectID]wire.Object),
}
// Attempt our decode.
- return ds.safely(func() {
- ds.Deserialize(reflect.ValueOf(rootPtr).Elem())
+ err := safely(func() {
+ ds.Load(reflect.ValueOf(rootPtr).Elem())
})
+ return ds.stats, err
}
-// Fns are the state dispatch functions.
-type Fns struct {
- // Save is a function like Save(concreteType, Map).
- Save interface{}
-
- // Load is a function like Load(concreteType, Map).
- Load interface{}
+// Sink is used for Type.StateSave.
+type Sink struct {
+ internal objectEncoder
}
-// Save executes the save function.
-func (fns *Fns) invokeSave(obj reflect.Value, m Map) {
- reflect.ValueOf(fns.Save).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+// Save adds the given object to the map.
+//
+// You should pass always pointers to the object you are saving. For example:
+//
+// type X struct {
+// A int
+// B *int
+// }
+//
+// func (x *X) StateTypeInfo(m Sink) state.TypeInfo {
+// return state.TypeInfo{
+// Name: "pkg.X",
+// Fields: []string{
+// "A",
+// "B",
+// },
+// }
+// }
+//
+// func (x *X) StateSave(m Sink) {
+// m.Save(0, &x.A) // Field is A.
+// m.Save(1, &x.B) // Field is B.
+// }
+//
+// func (x *X) StateLoad(m Source) {
+// m.Load(0, &x.A) // Field is A.
+// m.Load(1, &x.B) // Field is B.
+// }
+func (s Sink) Save(slot int, objPtr interface{}) {
+ s.internal.save(slot, reflect.ValueOf(objPtr).Elem())
}
-// Load executes the load function.
-func (fns *Fns) invokeLoad(obj reflect.Value, m Map) {
- reflect.ValueOf(fns.Load).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+// SaveValue adds the given object value to the map.
+//
+// This should be used for values where pointers are not available, or casts
+// are required during Save/Load.
+//
+// For example, if we want to cast external package type P.Foo to int64:
+//
+// func (x *X) StateSave(m Sink) {
+// m.SaveValue(0, "A", int64(x.A))
+// }
+//
+// func (x *X) StateLoad(m Source) {
+// m.LoadValue(0, new(int64), func(x interface{}) {
+// x.A = P.Foo(x.(int64))
+// })
+// }
+func (s Sink) SaveValue(slot int, obj interface{}) {
+ s.internal.save(slot, reflect.ValueOf(obj))
}
-// validateStateFn ensures types are correct.
-func validateStateFn(fn interface{}, typ reflect.Type) bool {
- fnTyp := reflect.TypeOf(fn)
- if fnTyp.Kind() != reflect.Func {
- return false
- }
- if fnTyp.NumIn() != 2 {
- return false
- }
- if fnTyp.NumOut() != 0 {
- return false
- }
- if fnTyp.In(0) != typ {
- return false
- }
- if fnTyp.In(1) != reflect.TypeOf(Map{}) {
- return false
- }
- return true
+// Context returns the context object provided at save time.
+func (s Sink) Context() context.Context {
+ return s.internal.es.ctx
}
-// Validate validates all state functions.
-func (fns *Fns) Validate(typ reflect.Type) bool {
- return validateStateFn(fns.Save, typ) && validateStateFn(fns.Load, typ)
+// Type is an interface that must be implemented by Struct objects. This allows
+// these objects to be serialized while minimizing runtime reflection required.
+//
+// All these methods can be automatically generated by the go_statify tool.
+type Type interface {
+ // StateTypeName returns the type's name.
+ //
+ // This is used for matching type information during encoding and
+ // decoding, as well as dynamic interface dispatch. This should be
+ // globally unique.
+ StateTypeName() string
+
+ // StateFields returns information about the type.
+ //
+ // Fields is the set of fields for the object. Calls to Sink.Save and
+ // Source.Load must be made in-order with respect to these fields.
+ //
+ // This will be called at most once per serialization.
+ StateFields() []string
}
-type typeDatabase struct {
- // nameToType is a forward lookup table.
- nameToType map[string]reflect.Type
-
- // typeToName is the reverse lookup table.
- typeToName map[reflect.Type]string
+// SaverLoader must be implemented by struct types.
+type SaverLoader interface {
+ // StateSave saves the state of the object to the given Map.
+ StateSave(Sink)
- // typeToFns is the function lookup table.
- typeToFns map[reflect.Type]Fns
+ // StateLoad loads the state of the object.
+ StateLoad(Source)
}
-// registeredTypes is a database used for SaveInterface and LoadInterface.
-var registeredTypes = typeDatabase{
- nameToType: make(map[string]reflect.Type),
- typeToName: make(map[reflect.Type]string),
- typeToFns: make(map[reflect.Type]Fns),
+// Source is used for Type.StateLoad.
+type Source struct {
+ internal objectDecoder
}
-// register registers a type under the given name. This will generally be
-// called via init() methods, and therefore uses panic to propagate errors.
-func (t *typeDatabase) register(name string, typ reflect.Type, fns Fns) {
- // We can't allow name collisions.
- if ot, ok := t.nameToType[name]; ok {
- panic(fmt.Sprintf("type %q can't use name %q, already in use by type %q", typ.Name(), name, ot.Name()))
- }
-
- // Or multiple registrations.
- if on, ok := t.typeToName[typ]; ok {
- panic(fmt.Sprintf("type %q can't be registered as %q, already registered as %q", typ.Name(), name, on))
- }
-
- t.nameToType[name] = typ
- t.typeToName[typ] = name
- t.typeToFns[typ] = fns
+// Load loads the given object passed as a pointer..
+//
+// See Sink.Save for an example.
+func (s Source) Load(slot int, objPtr interface{}) {
+ s.internal.load(slot, reflect.ValueOf(objPtr), false, nil)
}
-// lookupType finds a type given a name.
-func (t *typeDatabase) lookupType(name string) (reflect.Type, bool) {
- typ, ok := t.nameToType[name]
- return typ, ok
+// LoadWait loads the given objects from the map, and marks it as requiring all
+// AfterLoad executions to complete prior to running this object's AfterLoad.
+//
+// See Sink.Save for an example.
+func (s Source) LoadWait(slot int, objPtr interface{}) {
+ s.internal.load(slot, reflect.ValueOf(objPtr), true, nil)
}
-// lookupName finds a name given a type.
-func (t *typeDatabase) lookupName(typ reflect.Type) (string, bool) {
- name, ok := t.typeToName[typ]
- return name, ok
+// LoadValue loads the given object value from the map.
+//
+// See Sink.SaveValue for an example.
+func (s Source) LoadValue(slot int, objPtr interface{}, fn func(interface{})) {
+ o := reflect.ValueOf(objPtr)
+ s.internal.load(slot, o, true, func() { fn(o.Elem().Interface()) })
}
-// lookupFns finds functions given a type.
-func (t *typeDatabase) lookupFns(typ reflect.Type) (Fns, bool) {
- fns, ok := t.typeToFns[typ]
- return fns, ok
+// AfterLoad schedules a function execution when all objects have been
+// allocated and their automated loading and customized load logic have been
+// executed. fn will not be executed until all of current object's
+// dependencies' AfterLoad() logic, if exist, have been executed.
+func (s Source) AfterLoad(fn func()) {
+ s.internal.afterLoad(fn)
}
-// Register must be called for any interface implementation types that
-// implements Loader.
-//
-// Register should be called either immediately after startup or via init()
-// methods. Double registration of either names or types will result in a panic.
-//
-// No synchronization is provided; this should only be called in init.
-//
-// Example usage:
-//
-// state.Register("Foo", (*Foo)(nil), state.Fns{
-// Save: (*Foo).Save,
-// Load: (*Foo).Load,
-// })
-//
-func Register(name string, instance interface{}, fns Fns) {
- registeredTypes.register(name, reflect.TypeOf(instance), fns)
+// Context returns the context object provided at load time.
+func (s Source) Context() context.Context {
+ return s.internal.ds.ctx
}
// IsZeroValue checks if the given value is the zero value.
@@ -244,72 +260,14 @@ func IsZeroValue(val interface{}) bool {
return val == nil || reflect.ValueOf(val).Elem().IsZero()
}
-// step captures one encoding / decoding step. On each step, there is up to one
-// choice made, which is captured by non-nil param. We intentionally do not
-// eagerly create the final path string, as that will only be needed upon panic.
-type step struct {
- // dereference indicate if the current object is obtained by
- // dereferencing a pointer.
- dereference bool
-
- // format is the formatting string that takes param below, if
- // non-nil. For example, in array indexing case, we have "[%d]".
- format string
-
- // param stores the choice made at the current encoding / decoding step.
- // For eaxmple, in array indexing case, param stores the index. When no
- // choice is made, e.g. dereference, param should be nil.
- param interface{}
-}
-
-// recoverable is the state encoding / decoding panic recovery facility. It is
-// also used to store encoding / decoding steps as well as the reference to the
-// original queued object from which the current object is dispatched. The
-// complete encoding / decoding path is synthesised from the steps in all queued
-// objects leading to the current object.
-type recoverable struct {
- from *recoverable
- steps []step
+// Failf is a wrapper around panic that should be used to generate errors that
+// can be caught during saving and loading.
+func Failf(fmtStr string, v ...interface{}) {
+ panic(fmt.Errorf(fmtStr, v...))
}
-// push enters a new context level.
-func (sr *recoverable) push(dereference bool, format string, param interface{}) {
- sr.steps = append(sr.steps, step{dereference, format, param})
-}
-
-// pop exits the current context level.
-func (sr *recoverable) pop() {
- if len(sr.steps) <= 1 {
- return
- }
- sr.steps = sr.steps[:len(sr.steps)-1]
-}
-
-// path returns the complete encoding / decoding path from root. This is only
-// called upon panic.
-func (sr *recoverable) path() string {
- if sr.from == nil {
- return "root"
- }
- p := sr.from.path()
- for _, s := range sr.steps {
- if s.dereference {
- p = fmt.Sprintf("*(%s)", p)
- }
- if s.param == nil {
- p += s.format
- } else {
- p += fmt.Sprintf(s.format, s.param)
- }
- }
- return p
-}
-
-func (sr *recoverable) copy() recoverable {
- return recoverable{from: sr.from, steps: append([]step(nil), sr.steps...)}
-}
-
-// safely executes the given function, catching a panic and unpacking as an error.
+// safely executes the given function, catching a panic and unpacking as an
+// error.
//
// The error flow through the state package uses panic and recover. There are
// two important reasons for this:
@@ -323,9 +281,15 @@ func (sr *recoverable) copy() recoverable {
// method doesn't add a lot of value. If there are specific error conditions
// that you'd like to handle, you should add appropriate functionality to
// objects themselves prior to calling Save() and Load().
-func (sr *recoverable) safely(fn func()) (err error) {
+func safely(fn func()) (err error) {
defer func() {
if r := recover(); r != nil {
+ if es, ok := r.(*ErrState); ok {
+ err = es // Propagate.
+ return
+ }
+
+ // Build a new state error.
es := new(ErrState)
if e, ok := r.(error); ok {
es.err = e
@@ -333,8 +297,6 @@ func (sr *recoverable) safely(fn func()) (err error) {
es.err = fmt.Errorf("%v", r)
}
- es.path = sr.path()
-
// Make a stack. We don't know how big it will be ahead
// of time, but want to make sure we get the whole
// thing. So we just do a stupid brute force approach.