From a99d3479a84ca86843e500dbdf58db0af389b536 Mon Sep 17 00:00:00 2001
From: Adin Scannell <ascannell@google.com>
Date: Thu, 31 Oct 2019 18:02:04 -0700
Subject: Add context to state.

PiperOrigin-RevId: 277840416
---
 pkg/sentry/kernel/context.go | 32 ++++++++++++++++++++++++++++++++
 pkg/sentry/kernel/kernel.go  | 13 +++++++------
 2 files changed, 39 insertions(+), 6 deletions(-)

(limited to 'pkg/sentry/kernel')

diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go
index e3f5b0d83..3c9dceaba 100644
--- a/pkg/sentry/kernel/context.go
+++ b/pkg/sentry/kernel/context.go
@@ -15,6 +15,8 @@
 package kernel
 
 import (
+	"time"
+
 	"gvisor.dev/gvisor/pkg/log"
 	"gvisor.dev/gvisor/pkg/sentry/context"
 )
@@ -97,6 +99,21 @@ func TaskFromContext(ctx context.Context) *Task {
 	return nil
 }
 
+// Deadline implements context.Context.Deadline.
+func (*Task) Deadline() (time.Time, bool) {
+	return time.Time{}, false
+}
+
+// Done implements context.Context.Done.
+func (*Task) Done() <-chan struct{} {
+	return nil
+}
+
+// Err implements context.Context.Err.
+func (*Task) Err() error {
+	return nil
+}
+
 // AsyncContext returns a context.Context that may be used by goroutines that
 // do work on behalf of t and therefore share its contextual values, but are
 // not t's task goroutine (e.g. asynchronous I/O).
@@ -129,6 +146,21 @@ func (ctx taskAsyncContext) IsLogging(level log.Level) bool {
 	return ctx.t.IsLogging(level)
 }
 
+// Deadline implements context.Context.Deadline.
+func (ctx taskAsyncContext) Deadline() (time.Time, bool) {
+	return ctx.t.Deadline()
+}
+
+// Done implements context.Context.Done.
+func (ctx taskAsyncContext) Done() <-chan struct{} {
+	return ctx.t.Done()
+}
+
+// Err implements context.Context.Err.
+func (ctx taskAsyncContext) Err() error {
+	return ctx.t.Err()
+}
+
 // Value implements context.Context.Value.
 func (ctx taskAsyncContext) Value(key interface{}) interface{} {
 	return ctx.t.Value(key)
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index e64d648e2..28ba950bd 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -391,7 +391,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
 	//
 	// N.B. This will also be saved along with the full kernel save below.
 	cpuidStart := time.Now()
-	if err := state.Save(w, k.FeatureSet(), nil); err != nil {
+	if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil {
 		return err
 	}
 	log.Infof("CPUID save took [%s].", time.Since(cpuidStart))
@@ -399,7 +399,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
 	// Save the kernel state.
 	kernelStart := time.Now()
 	var stats state.Stats
-	if err := state.Save(w, k, &stats); err != nil {
+	if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil {
 		return err
 	}
 	log.Infof("Kernel save stats: %s", &stats)
@@ -407,7 +407,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
 
 	// Save the memory file's state.
 	memoryStart := time.Now()
-	if err := k.mf.SaveTo(w); err != nil {
+	if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil {
 		return err
 	}
 	log.Infof("Memory save took [%s].", time.Since(memoryStart))
@@ -542,7 +542,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
 	// don't need to explicitly install it in the Kernel.
 	cpuidStart := time.Now()
 	var features cpuid.FeatureSet
-	if err := state.Load(r, &features, nil); err != nil {
+	if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil {
 		return err
 	}
 	log.Infof("CPUID load took [%s].", time.Since(cpuidStart))
@@ -558,7 +558,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
 	// Load the kernel state.
 	kernelStart := time.Now()
 	var stats state.Stats
-	if err := state.Load(r, k, &stats); err != nil {
+	if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil {
 		return err
 	}
 	log.Infof("Kernel load stats: %s", &stats)
@@ -566,7 +566,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
 
 	// Load the memory file's state.
 	memoryStart := time.Now()
-	if err := k.mf.LoadFrom(r); err != nil {
+	if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
 		return err
 	}
 	log.Infof("Memory load took [%s].", time.Since(memoryStart))
@@ -1322,6 +1322,7 @@ func (k *Kernel) ListSockets() []*SocketEntry {
 	return socks
 }
 
+// supervisorContext is a privileged context.
 type supervisorContext struct {
 	context.NoopSleeper
 	log.Logger
-- 
cgit v1.2.3