diff options
Diffstat (limited to 'pkg/state/state.go')
-rw-r--r-- | pkg/state/state.go | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/pkg/state/state.go b/pkg/state/state.go index d408ff84a..03ae2dbb0 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -50,6 +50,7 @@ package state import ( + "context" "fmt" "io" "reflect" @@ -86,9 +87,10 @@ func UnwrapErrState(err error) error { } // Save saves the given object state. -func Save(w io.Writer, rootPtr interface{}, stats *Stats) error { +func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error { // Create the encoding state. es := &encodeState{ + ctx: ctx, idsByObject: make(map[uintptr]uint64), w: w, stats: stats, @@ -101,9 +103,10 @@ func Save(w io.Writer, rootPtr interface{}, stats *Stats) error { } // Load loads a checkpoint. -func Load(r io.Reader, rootPtr interface{}, stats *Stats) error { +func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error { // Create the decoding state. ds := &decodeState{ + ctx: ctx, objectsByID: make(map[uint64]*objectState), deferred: make(map[uint64]*pb.Object), r: r, @@ -238,10 +241,7 @@ func Register(name string, instance interface{}, fns Fns) { // // This function is used by the stateify tool. func IsZeroValue(val interface{}) bool { - if val == nil { - return true - } - return reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface()) + return val == nil || reflect.ValueOf(val).Elem().IsZero() } // step captures one encoding / decoding step. On each step, there is up to one |