summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAdin Scannell <ascannell@google.com>2019-10-31 18:02:04 -0700
committergVisor bot <gvisor-bot@google.com>2019-10-31 18:03:24 -0700
commita99d3479a84ca86843e500dbdf58db0af389b536 (patch)
tree4bce14ea740020ee2652a6a874517550174d1f34
parent36837c4ad3f3c840791379db81d02b60d918c0f5 (diff)
Add context to state.
PiperOrigin-RevId: 277840416
-rw-r--r--pkg/sentry/context/context.go63
-rw-r--r--pkg/sentry/kernel/context.go32
-rw-r--r--pkg/sentry/kernel/kernel.go13
-rw-r--r--pkg/sentry/pgalloc/save_restore.go13
-rw-r--r--pkg/state/decode.go4
-rw-r--r--pkg/state/encode.go4
-rw-r--r--pkg/state/map.go11
-rw-r--r--pkg/state/state.go7
-rw-r--r--pkg/state/state_test.go11
9 files changed, 115 insertions, 43 deletions
diff --git a/pkg/sentry/context/context.go b/pkg/sentry/context/context.go
index dfd62cbdb..23e009ef3 100644
--- a/pkg/sentry/context/context.go
+++ b/pkg/sentry/context/context.go
@@ -12,10 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package context defines the sentry's Context type.
+// Package context defines an internal context type.
+//
+// The given Context conforms to the standard Go context, but mandates
+// additional methods that are specific to the kernel internals. Note however,
+// that the Context described by this package carries additional constraints
+// regarding concurrent access and retaining beyond the scope of a call.
+//
+// See the Context type for complete details.
package context
import (
+ "context"
+ "time"
+
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/log"
)
@@ -59,6 +69,7 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) {
type Context interface {
log.Logger
amutex.Sleeper
+ context.Context
// UninterruptibleSleepStart indicates the beginning of an uninterruptible
// sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate
@@ -72,19 +83,36 @@ type Context interface {
// AddressSpace is activated. Normally activate is the same value as the
// deactivate parameter passed to UninterruptibleSleepStart.
UninterruptibleSleepFinish(activate bool)
+}
+
+// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep
+// methods for anonymous embedding in other types that do not implement sleeps.
+type NoopSleeper struct {
+ amutex.NoopSleeper
+}
+
+// UninterruptibleSleepStart does nothing.
+func (NoopSleeper) UninterruptibleSleepStart(bool) {}
+
+// UninterruptibleSleepFinish does nothing.
+func (NoopSleeper) UninterruptibleSleepFinish(bool) {}
+
+// Deadline returns zero values, meaning no deadline.
+func (NoopSleeper) Deadline() (time.Time, bool) {
+ return time.Time{}, false
+}
+
+// Done returns nil.
+func (NoopSleeper) Done() <-chan struct{} {
+ return nil
+}
- // Value returns the value associated with this Context for key, or nil if
- // no value is associated with key. Successive calls to Value with the same
- // key returns the same result.
- //
- // A key identifies a specific value in a Context. Functions that wish to
- // retrieve values from Context typically allocate a key in a global
- // variable then use that key as the argument to Context.Value. A key can
- // be any type that supports equality; packages should define keys as an
- // unexported type to avoid collisions.
- Value(key interface{}) interface{}
+// Err returns nil.
+func (NoopSleeper) Err() error {
+ return nil
}
+// logContext implements basic logging.
type logContext struct {
log.Logger
NoopSleeper
@@ -95,19 +123,6 @@ func (logContext) Value(key interface{}) interface{} {
return nil
}
-// NoopSleeper is a noop implementation of amutex.Sleeper and
-// Context.UninterruptibleSleep* methods for anonymous embedding in other types
-// that do not want to notify kernel.Task about sleeps.
-type NoopSleeper struct {
- amutex.NoopSleeper
-}
-
-// UninterruptibleSleepStart does nothing.
-func (NoopSleeper) UninterruptibleSleepStart(bool) {}
-
-// UninterruptibleSleepFinish does nothing.
-func (NoopSleeper) UninterruptibleSleepFinish(bool) {}
-
// bgContext is the context returned by context.Background.
var bgContext = &logContext{Logger: log.Log()}
diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go
index e3f5b0d83..3c9dceaba 100644
--- a/pkg/sentry/kernel/context.go
+++ b/pkg/sentry/kernel/context.go
@@ -15,6 +15,8 @@
package kernel
import (
+ "time"
+
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
)
@@ -97,6 +99,21 @@ func TaskFromContext(ctx context.Context) *Task {
return nil
}
+// Deadline implements context.Context.Deadline.
+func (*Task) Deadline() (time.Time, bool) {
+ return time.Time{}, false
+}
+
+// Done implements context.Context.Done.
+func (*Task) Done() <-chan struct{} {
+ return nil
+}
+
+// Err implements context.Context.Err.
+func (*Task) Err() error {
+ return nil
+}
+
// AsyncContext returns a context.Context that may be used by goroutines that
// do work on behalf of t and therefore share its contextual values, but are
// not t's task goroutine (e.g. asynchronous I/O).
@@ -129,6 +146,21 @@ func (ctx taskAsyncContext) IsLogging(level log.Level) bool {
return ctx.t.IsLogging(level)
}
+// Deadline implements context.Context.Deadline.
+func (ctx taskAsyncContext) Deadline() (time.Time, bool) {
+ return ctx.t.Deadline()
+}
+
+// Done implements context.Context.Done.
+func (ctx taskAsyncContext) Done() <-chan struct{} {
+ return ctx.t.Done()
+}
+
+// Err implements context.Context.Err.
+func (ctx taskAsyncContext) Err() error {
+ return ctx.t.Err()
+}
+
// Value implements context.Context.Value.
func (ctx taskAsyncContext) Value(key interface{}) interface{} {
return ctx.t.Value(key)
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index e64d648e2..28ba950bd 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -391,7 +391,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
//
// N.B. This will also be saved along with the full kernel save below.
cpuidStart := time.Now()
- if err := state.Save(w, k.FeatureSet(), nil); err != nil {
+ if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil {
return err
}
log.Infof("CPUID save took [%s].", time.Since(cpuidStart))
@@ -399,7 +399,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
// Save the kernel state.
kernelStart := time.Now()
var stats state.Stats
- if err := state.Save(w, k, &stats); err != nil {
+ if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil {
return err
}
log.Infof("Kernel save stats: %s", &stats)
@@ -407,7 +407,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
// Save the memory file's state.
memoryStart := time.Now()
- if err := k.mf.SaveTo(w); err != nil {
+ if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil {
return err
}
log.Infof("Memory save took [%s].", time.Since(memoryStart))
@@ -542,7 +542,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// don't need to explicitly install it in the Kernel.
cpuidStart := time.Now()
var features cpuid.FeatureSet
- if err := state.Load(r, &features, nil); err != nil {
+ if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil {
return err
}
log.Infof("CPUID load took [%s].", time.Since(cpuidStart))
@@ -558,7 +558,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// Load the kernel state.
kernelStart := time.Now()
var stats state.Stats
- if err := state.Load(r, k, &stats); err != nil {
+ if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil {
return err
}
log.Infof("Kernel load stats: %s", &stats)
@@ -566,7 +566,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// Load the memory file's state.
memoryStart := time.Now()
- if err := k.mf.LoadFrom(r); err != nil {
+ if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
return err
}
log.Infof("Memory load took [%s].", time.Since(memoryStart))
@@ -1322,6 +1322,7 @@ func (k *Kernel) ListSockets() []*SocketEntry {
return socks
}
+// supervisorContext is a privileged context.
type supervisorContext struct {
context.NoopSleeper
log.Logger
diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go
index 1effc7735..aafce1d00 100644
--- a/pkg/sentry/pgalloc/save_restore.go
+++ b/pkg/sentry/pgalloc/save_restore.go
@@ -16,6 +16,7 @@ package pgalloc
import (
"bytes"
+ "context"
"fmt"
"io"
"runtime"
@@ -29,7 +30,7 @@ import (
)
// SaveTo writes f's state to the given stream.
-func (f *MemoryFile) SaveTo(w io.Writer) error {
+func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error {
// Wait for reclaim.
f.mu.Lock()
defer f.mu.Unlock()
@@ -78,10 +79,10 @@ func (f *MemoryFile) SaveTo(w io.Writer) error {
}
// Save metadata.
- if err := state.Save(w, &f.fileSize, nil); err != nil {
+ if err := state.Save(ctx, w, &f.fileSize, nil); err != nil {
return err
}
- if err := state.Save(w, &f.usage, nil); err != nil {
+ if err := state.Save(ctx, w, &f.usage, nil); err != nil {
return err
}
@@ -114,9 +115,9 @@ func (f *MemoryFile) SaveTo(w io.Writer) error {
}
// LoadFrom loads MemoryFile state from the given stream.
-func (f *MemoryFile) LoadFrom(r io.Reader) error {
+func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error {
// Load metadata.
- if err := state.Load(r, &f.fileSize, nil); err != nil {
+ if err := state.Load(ctx, r, &f.fileSize, nil); err != nil {
return err
}
if err := f.file.Truncate(f.fileSize); err != nil {
@@ -124,7 +125,7 @@ func (f *MemoryFile) LoadFrom(r io.Reader) error {
}
newMappings := make([]uintptr, f.fileSize>>chunkShift)
f.mappings.Store(newMappings)
- if err := state.Load(r, &f.usage, nil); err != nil {
+ if err := state.Load(ctx, r, &f.usage, nil); err != nil {
return err
}
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
index 47e6b878a..590c241a3 100644
--- a/pkg/state/decode.go
+++ b/pkg/state/decode.go
@@ -16,6 +16,7 @@ package state
import (
"bytes"
+ "context"
"encoding/binary"
"errors"
"fmt"
@@ -133,6 +134,9 @@ func (os *objectState) findCycle() []*objectState {
// to ensure that all callbacks are executed, otherwise the callback graph was
// not acyclic.
type decodeState struct {
+ // ctx is the decode context.
+ ctx context.Context
+
// objectByID is the set of objects in progress.
objectsByID map[uint64]*objectState
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
index 5d9409a45..c5118d3a9 100644
--- a/pkg/state/encode.go
+++ b/pkg/state/encode.go
@@ -16,6 +16,7 @@ package state
import (
"container/list"
+ "context"
"encoding/binary"
"fmt"
"io"
@@ -38,6 +39,9 @@ type queuedObject struct {
// The encoding process is a breadth-first traversal of the object graph. The
// inherent races and dependencies are much simpler than the decode case.
type encodeState struct {
+ // ctx is the encode context.
+ ctx context.Context
+
// lastID is the last object ID.
//
// See idsByObject for context. Because of the special zero encoding
diff --git a/pkg/state/map.go b/pkg/state/map.go
index 7e6fefed4..4f3ebb0da 100644
--- a/pkg/state/map.go
+++ b/pkg/state/map.go
@@ -15,6 +15,7 @@
package state
import (
+ "context"
"fmt"
"reflect"
"sort"
@@ -219,3 +220,13 @@ func (m Map) AfterLoad(fn func()) {
// data dependencies have been cleared.
m.os.callbacks = append(m.os.callbacks, fn)
}
+
+// Context returns the current context object.
+func (m Map) Context() context.Context {
+ if m.es != nil {
+ return m.es.ctx
+ } else if m.ds != nil {
+ return m.ds.ctx
+ }
+ return context.Background() // No context.
+}
diff --git a/pkg/state/state.go b/pkg/state/state.go
index d408ff84a..dbe507ab4 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -50,6 +50,7 @@
package state
import (
+ "context"
"fmt"
"io"
"reflect"
@@ -86,9 +87,10 @@ func UnwrapErrState(err error) error {
}
// Save saves the given object state.
-func Save(w io.Writer, rootPtr interface{}, stats *Stats) error {
+func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error {
// Create the encoding state.
es := &encodeState{
+ ctx: ctx,
idsByObject: make(map[uintptr]uint64),
w: w,
stats: stats,
@@ -101,9 +103,10 @@ func Save(w io.Writer, rootPtr interface{}, stats *Stats) error {
}
// Load loads a checkpoint.
-func Load(r io.Reader, rootPtr interface{}, stats *Stats) error {
+func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error {
// Create the decoding state.
ds := &decodeState{
+ ctx: ctx,
objectsByID: make(map[uint64]*objectState),
deferred: make(map[uint64]*pb.Object),
r: r,
diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go
index 7c24bbcda..d7221e9e8 100644
--- a/pkg/state/state_test.go
+++ b/pkg/state/state_test.go
@@ -16,6 +16,7 @@ package state
import (
"bytes"
+ "context"
"io/ioutil"
"math"
"reflect"
@@ -46,7 +47,7 @@ func runTest(t *testing.T, tests []TestCase) {
saveBuffer := &bytes.Buffer{}
saveObjectPtr := reflect.New(reflect.TypeOf(root))
saveObjectPtr.Elem().Set(reflect.ValueOf(root))
- if err := Save(saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail {
+ if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail {
t.Errorf(" FAIL: Save failed unexpectedly: %v", err)
continue
} else if err != nil {
@@ -56,7 +57,7 @@ func runTest(t *testing.T, tests []TestCase) {
// Load a new copy of the object.
loadObjectPtr := reflect.New(reflect.TypeOf(root))
- if err := Load(bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail {
+ if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail {
t.Errorf(" FAIL: Load failed unexpectedly: %v", err)
continue
} else if err != nil {
@@ -624,7 +625,7 @@ func BenchmarkEncoding(b *testing.B) {
bs := buildObject(b.N)
var stats Stats
b.StartTimer()
- if err := Save(ioutil.Discard, bs, &stats); err != nil {
+ if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil {
b.Errorf("save failed: %v", err)
}
b.StopTimer()
@@ -638,12 +639,12 @@ func BenchmarkDecoding(b *testing.B) {
bs := buildObject(b.N)
var newBS benchStruct
buf := &bytes.Buffer{}
- if err := Save(buf, bs, nil); err != nil {
+ if err := Save(context.Background(), buf, bs, nil); err != nil {
b.Errorf("save failed: %v", err)
}
var stats Stats
b.StartTimer()
- if err := Load(buf, &newBS, &stats); err != nil {
+ if err := Load(context.Background(), buf, &newBS, &stats); err != nil {
b.Errorf("load failed: %v", err)
}
b.StopTimer()