From a1cb52447f3e9414211b9e0558f1231ae3e59329 Mon Sep 17 00:00:00 2001 From: Jamie Liu Date: Fri, 13 Nov 2020 14:46:03 -0800 Subject: Check for misuse of kernel.Task as context.Context. Checks in Task.block() and Task.Value() are conditional on race detection being enabled, since these functions are relatively hot. Checks in Task.SleepStart() and Task.UninterruptibleSleepStart() are enabled unconditionally, since these functions are not thought to lie on any critical paths, and misuse of these functions is required for b/168241471 to manifest. PiperOrigin-RevId: 342342175 --- pkg/sentry/kernel/BUILD | 2 + pkg/sentry/kernel/aio.go | 50 +--------- pkg/sentry/kernel/context.go | 17 ---- pkg/sentry/kernel/task.go | 69 ++------------ pkg/sentry/kernel/task_block.go | 16 +++- pkg/sentry/kernel/task_context.go | 189 ++++++++++++++++++++++++++++++++++++++ pkg/sentry/kernel/task_run.go | 15 +++ 7 files changed, 227 insertions(+), 131 deletions(-) create mode 100644 pkg/sentry/kernel/task_context.go diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index ff9f0b81c..0ee60569c 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -179,6 +179,7 @@ go_library( "task_acct.go", "task_block.go", "task_clone.go", + "task_context.go", "task_exec.go", "task_exit.go", "task_futex.go", @@ -224,6 +225,7 @@ go_library( "//pkg/cpuid", "//pkg/eventchannel", "//pkg/fspath", + "//pkg/goid", "//pkg/log", "//pkg/marshal", "//pkg/marshal/primitive", diff --git a/pkg/sentry/kernel/aio.go b/pkg/sentry/kernel/aio.go index 0ac78c0b8..ec36d1a49 100644 --- a/pkg/sentry/kernel/aio.go +++ b/pkg/sentry/kernel/aio.go @@ -15,10 +15,7 @@ package kernel import ( - "time" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" ) // AIOCallback is an function that does asynchronous I/O on behalf of a task. @@ -26,7 +23,7 @@ type AIOCallback func(context.Context) // QueueAIO queues an AIOCallback which will be run asynchronously. func (t *Task) QueueAIO(cb AIOCallback) { - ctx := taskAsyncContext{t: t} + ctx := t.AsyncContext() wg := &t.TaskSet().aioGoroutines wg.Add(1) go func() { @@ -34,48 +31,3 @@ func (t *Task) QueueAIO(cb AIOCallback) { wg.Done() }() } - -type taskAsyncContext struct { - context.NoopSleeper - t *Task -} - -// Debugf implements log.Logger.Debugf. -func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) { - ctx.t.Debugf(format, v...) -} - -// Infof implements log.Logger.Infof. -func (ctx taskAsyncContext) Infof(format string, v ...interface{}) { - ctx.t.Infof(format, v...) -} - -// Warningf implements log.Logger.Warningf. -func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) { - ctx.t.Warningf(format, v...) -} - -// IsLogging implements log.Logger.IsLogging. -func (ctx taskAsyncContext) IsLogging(level log.Level) bool { - return ctx.t.IsLogging(level) -} - -// Deadline implements context.Context.Deadline. -func (ctx taskAsyncContext) Deadline() (time.Time, bool) { - return ctx.t.Deadline() -} - -// Done implements context.Context.Done. -func (ctx taskAsyncContext) Done() <-chan struct{} { - return ctx.t.Done() -} - -// Err implements context.Context.Err. -func (ctx taskAsyncContext) Err() error { - return ctx.t.Err() -} - -// Value implements context.Context.Value. -func (ctx taskAsyncContext) Value(key interface{}) interface{} { - return ctx.t.Value(key) -} diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go index bb94769c4..a8596410f 100644 --- a/pkg/sentry/kernel/context.go +++ b/pkg/sentry/kernel/context.go @@ -15,8 +15,6 @@ package kernel import ( - "time" - "gvisor.dev/gvisor/pkg/context" ) @@ -98,18 +96,3 @@ func TaskFromContext(ctx context.Context) *Task { } return nil } - -// Deadline implements context.Context.Deadline. -func (*Task) Deadline() (time.Time, bool) { - return time.Time{}, false -} - -// Done implements context.Context.Done. -func (*Task) Done() <-chan struct{} { - return nil -} - -// Err implements context.Context.Err. -func (*Task) Err() error { - return nil -} diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index e830cf948..c0ab53c94 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" - "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -29,11 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/unimpl" - "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -63,6 +58,12 @@ import ( type Task struct { taskNode + // goid is the task goroutine's ID. goid is owned by the task goroutine, + // but since it's used to detect cases where non-task goroutines + // incorrectly access state owned by, or exclusive to, the task goroutine, + // goid is always accessed using atomic memory operations. + goid int64 `state:"nosave"` + // runState is what the task goroutine is executing if it is not stopped. // If runState is nil, the task goroutine should exit or has exited. // runState is exclusive to the task goroutine. @@ -641,64 +642,6 @@ func (t *Task) Kernel() *Kernel { return t.k } -// Value implements context.Context.Value. -// -// Preconditions: The caller must be running on the task goroutine (as implied -// by the requirements of context.Context). -func (t *Task) Value(key interface{}) interface{} { - switch key { - case CtxCanTrace: - return t.CanTrace - case CtxKernel: - return t.k - case CtxPIDNamespace: - return t.tg.pidns - case CtxUTSNamespace: - return t.utsns - case CtxIPCNamespace: - ipcns := t.IPCNamespace() - ipcns.IncRef() - return ipcns - case CtxTask: - return t - case auth.CtxCredentials: - return t.Credentials() - case context.CtxThreadGroupID: - return int32(t.ThreadGroup().ID()) - case fs.CtxRoot: - return t.fsContext.RootDirectory() - case vfs.CtxRoot: - return t.fsContext.RootDirectoryVFS2() - case vfs.CtxMountNamespace: - t.mountNamespaceVFS2.IncRef() - return t.mountNamespaceVFS2 - case fs.CtxDirentCacheLimiter: - return t.k.DirentCacheLimiter - case inet.CtxStack: - return t.NetworkContext() - case ktime.CtxRealtimeClock: - return t.k.RealtimeClock() - case limits.CtxLimits: - return t.tg.limits - case pgalloc.CtxMemoryFile: - return t.k.mf - case pgalloc.CtxMemoryFileProvider: - return t.k - case platform.CtxPlatform: - return t.k - case uniqueid.CtxGlobalUniqueID: - return t.k.UniqueID() - case uniqueid.CtxGlobalUniqueIDProvider: - return t.k - case uniqueid.CtxInotifyCookie: - return t.k.GenerateInotifyCookie() - case unimpl.CtxEvents: - return t.k - default: - return nil - } -} - // SetClearTID sets t's cleartid. // // Preconditions: The caller must be running on the task goroutine. diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index 4a4a69ee2..147a3d286 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -20,6 +20,7 @@ import ( "time" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -32,6 +33,8 @@ import ( // // - An error which is nil if an event is received from C, ETIMEDOUT if the timeout // expired, and syserror.ErrInterrupted if t is interrupted. +// +// Preconditions: The caller must be running on the task goroutine. func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time.Duration) (time.Duration, error) { if !haveTimeout { return timeout, t.block(C, nil) @@ -112,7 +115,14 @@ func (t *Task) Block(C <-chan struct{}) error { // block blocks a task on one of many events. // N.B. defer is too expensive to be used here. +// +// Preconditions: The caller must be running on the task goroutine. func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error { + // This function is very hot; skip this check outside of +race builds. + if sync.RaceEnabled { + t.assertTaskGoroutine() + } + // Fast path if the request is already done. select { case <-C: @@ -156,14 +166,15 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error { } } -// SleepStart implements amutex.Sleeper.SleepStart. +// SleepStart implements context.ChannelSleeper.SleepStart. func (t *Task) SleepStart() <-chan struct{} { + t.assertTaskGoroutine() t.Deactivate() t.accountTaskGoroutineEnter(TaskGoroutineBlockedInterruptible) return t.interruptChan } -// SleepFinish implements amutex.Sleeper.SleepFinish. +// SleepFinish implements context.ChannelSleeper.SleepFinish. func (t *Task) SleepFinish(success bool) { if !success { // The interrupted notification is consumed only at the top-level @@ -183,6 +194,7 @@ func (t *Task) Interrupted() bool { // UninterruptibleSleepStart implements context.Context.UninterruptibleSleepStart. func (t *Task) UninterruptibleSleepStart(deactivate bool) { + t.assertTaskGoroutine() if deactivate { t.Deactivate() } diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go new file mode 100644 index 000000000..70b0699dc --- /dev/null +++ b/pkg/sentry/kernel/task_context.go @@ -0,0 +1,189 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "time" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sentry/pgalloc" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/unimpl" + "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" +) + +// Deadline implements context.Context.Deadline. +func (t *Task) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +// Done implements context.Context.Done. +func (t *Task) Done() <-chan struct{} { + return nil +} + +// Err implements context.Context.Err. +func (t *Task) Err() error { + return nil +} + +// Value implements context.Context.Value. +// +// Preconditions: The caller must be running on the task goroutine. +func (t *Task) Value(key interface{}) interface{} { + // This function is very hot; skip this check outside of +race builds. + if sync.RaceEnabled { + t.assertTaskGoroutine() + } + return t.contextValue(key, true /* isTaskGoroutine */) +} + +func (t *Task) contextValue(key interface{}, isTaskGoroutine bool) interface{} { + switch key { + case CtxCanTrace: + return t.CanTrace + case CtxKernel: + return t.k + case CtxPIDNamespace: + return t.tg.pidns + case CtxUTSNamespace: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() + } + return t.utsns + case CtxIPCNamespace: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() + } + ipcns := t.ipcns + ipcns.IncRef() + return ipcns + case CtxTask: + return t + case auth.CtxCredentials: + return t.creds.Load() + case context.CtxThreadGroupID: + return int32(t.tg.ID()) + case fs.CtxRoot: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() + } + return t.fsContext.RootDirectory() + case vfs.CtxRoot: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() + } + return t.fsContext.RootDirectoryVFS2() + case vfs.CtxMountNamespace: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() + } + t.mountNamespaceVFS2.IncRef() + return t.mountNamespaceVFS2 + case fs.CtxDirentCacheLimiter: + return t.k.DirentCacheLimiter + case inet.CtxStack: + return t.NetworkContext() + case ktime.CtxRealtimeClock: + return t.k.RealtimeClock() + case limits.CtxLimits: + return t.tg.limits + case pgalloc.CtxMemoryFile: + return t.k.mf + case pgalloc.CtxMemoryFileProvider: + return t.k + case platform.CtxPlatform: + return t.k + case uniqueid.CtxGlobalUniqueID: + return t.k.UniqueID() + case uniqueid.CtxGlobalUniqueIDProvider: + return t.k + case uniqueid.CtxInotifyCookie: + return t.k.GenerateInotifyCookie() + case unimpl.CtxEvents: + return t.k + default: + return nil + } +} + +// taskAsyncContext implements context.Context for a goroutine that performs +// work on behalf of a Task, but is not the task goroutine. +type taskAsyncContext struct { + context.NoopSleeper + + t *Task +} + +// AsyncContext returns a context.Context representing t. The returned +// context.Context is intended for use by goroutines other than t's task +// goroutine; for example, signal delivery to t will not interrupt goroutines +// that are blocking using the returned context.Context. +func (t *Task) AsyncContext() context.Context { + return taskAsyncContext{t: t} +} + +// Debugf implements log.Logger.Debugf. +func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) { + ctx.t.Debugf(format, v...) +} + +// Infof implements log.Logger.Infof. +func (ctx taskAsyncContext) Infof(format string, v ...interface{}) { + ctx.t.Infof(format, v...) +} + +// Warningf implements log.Logger.Warningf. +func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) { + ctx.t.Warningf(format, v...) +} + +// IsLogging implements log.Logger.IsLogging. +func (ctx taskAsyncContext) IsLogging(level log.Level) bool { + return ctx.t.IsLogging(level) +} + +// Deadline implements context.Context.Deadline. +func (ctx taskAsyncContext) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +// Done implements context.Context.Done. +func (ctx taskAsyncContext) Done() <-chan struct{} { + return nil +} + +// Err implements context.Context.Err. +func (ctx taskAsyncContext) Err() error { + return nil +} + +// Value implements context.Context.Value. +func (ctx taskAsyncContext) Value(key interface{}) interface{} { + return ctx.t.contextValue(key, false /* isTaskGoroutine */) +} diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index 0f8294dcd..c5858da30 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -16,11 +16,13 @@ package kernel import ( "bytes" + "fmt" "runtime" "runtime/trace" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/goid" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/hostcpu" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -57,6 +59,8 @@ type taskRunState interface { // make it visible in stack dumps. A goroutine for a given task can be identified // searching for Task.run()'s argument value. func (t *Task) run(threadID uintptr) { + atomic.StoreInt64(&t.goid, goid.Get()) + // Construct t.blockingTimer here. We do this here because we can't // reconstruct t.blockingTimer during restore in Task.afterLoad(), because // kernel.timekeeper.SetClocks() hasn't been called yet. @@ -99,6 +103,9 @@ func (t *Task) run(threadID uintptr) { t.tg.pidns.owner.runningGoroutines.Done() t.p.Release() + // Deferring this store triggers a false positive in the race + // detector (https://github.com/golang/go/issues/42599). + atomic.StoreInt64(&t.goid, 0) // Keep argument alive because stack trace for dead variables may not be correct. runtime.KeepAlive(threadID) return @@ -375,6 +382,14 @@ func (app *runApp) execute(t *Task) taskRunState { } } +// assertTaskGoroutine panics if the caller is not running on t's task +// goroutine. +func (t *Task) assertTaskGoroutine() { + if got, want := goid.Get(), atomic.LoadInt64(&t.goid); got != want { + panic(fmt.Sprintf("running on goroutine %d (task goroutine for kernel.Task %p is %d)", got, t, want)) + } +} + // waitGoroutineStoppedOrExited blocks until t's task goroutine stops or exits. func (t *Task) waitGoroutineStoppedOrExited() { t.goroutineStopped.Wait() -- cgit v1.2.3