From 1c0535297067179a822ba2dd9a6fe13a8be5a666 Mon Sep 17 00:00:00 2001
From: Jamie Liu <jamieliu@google.com>
Date: Fri, 13 Mar 2020 13:17:59 -0700
Subject: Fix oom_score_adj.

- Make oomScoreAdj a ThreadGroup field (Linux: signal_struct::oom_score_adj).

- Avoid deadlock caused by Task.OOMScoreAdj()/SetOOMScoreAdj() locking Task.mu
  and TaskSet.mu in the wrong order (via Task.ExitState()).

PiperOrigin-RevId: 300814698
---
 pkg/sentry/kernel/task.go         | 29 ++++++-----------------------
 pkg/sentry/kernel/task_clone.go   |  9 +++------
 pkg/sentry/kernel/task_start.go   |  4 ----
 pkg/sentry/kernel/thread_group.go |  7 +++++++
 4 files changed, 16 insertions(+), 33 deletions(-)

(limited to 'pkg/sentry/kernel')

diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index c0dbbe890..8452ddf5b 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -555,13 +555,6 @@ type Task struct {
 	//
 	// startTime is protected by mu.
 	startTime ktime.Time
-
-	// oomScoreAdj is the task's OOM score adjustment. This is currently not
-	// used but is maintained for consistency.
-	// TODO(gvisor.dev/issue/1967)
-	//
-	// oomScoreAdj is protected by mu, and is owned by the task goroutine.
-	oomScoreAdj int32
 }
 
 func (t *Task) savePtraceTracer() *Task {
@@ -856,27 +849,17 @@ func (t *Task) ContainerID() string {
 	return t.containerID
 }
 
-// OOMScoreAdj gets the task's OOM score adjustment.
-func (t *Task) OOMScoreAdj() (int32, error) {
-	t.mu.Lock()
-	defer t.mu.Unlock()
-	if t.ExitState() == TaskExitDead {
-		return 0, syserror.ESRCH
-	}
-	return t.oomScoreAdj, nil
+// OOMScoreAdj gets the task's thread group's OOM score adjustment.
+func (t *Task) OOMScoreAdj() int32 {
+	return atomic.LoadInt32(&t.tg.oomScoreAdj)
 }
 
-// SetOOMScoreAdj sets the task's OOM score adjustment. The value should be
-// between -1000 and 1000 inclusive.
+// SetOOMScoreAdj sets the task's thread group's OOM score adjustment. The
+// value should be between -1000 and 1000 inclusive.
 func (t *Task) SetOOMScoreAdj(adj int32) error {
-	t.mu.Lock()
-	defer t.mu.Unlock()
-	if t.ExitState() == TaskExitDead {
-		return syserror.ESRCH
-	}
 	if adj > 1000 || adj < -1000 {
 		return syserror.EINVAL
 	}
-	t.oomScoreAdj = adj
+	atomic.StoreInt32(&t.tg.oomScoreAdj, adj)
 	return nil
 }
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index dda502bb8..e1ecca99e 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -15,6 +15,8 @@
 package kernel
 
 import (
+	"sync/atomic"
+
 	"gvisor.dev/gvisor/pkg/abi/linux"
 	"gvisor.dev/gvisor/pkg/bpf"
 	"gvisor.dev/gvisor/pkg/sentry/inet"
@@ -260,15 +262,11 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
 			sh = sh.Fork()
 		}
 		tg = t.k.NewThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy())
+		tg.oomScoreAdj = atomic.LoadInt32(&t.tg.oomScoreAdj)
 		rseqAddr = t.rseqAddr
 		rseqSignature = t.rseqSignature
 	}
 
-	adj, err := t.OOMScoreAdj()
-	if err != nil {
-		return 0, nil, err
-	}
-
 	cfg := &TaskConfig{
 		Kernel:                  t.k,
 		ThreadGroup:             tg,
@@ -287,7 +285,6 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
 		RSeqAddr:                rseqAddr,
 		RSeqSignature:           rseqSignature,
 		ContainerID:             t.ContainerID(),
-		OOMScoreAdj:             adj,
 	}
 	if opts.NewThreadGroup {
 		cfg.Parent = t
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index 2bbf48bb8..a5035bb7f 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -93,9 +93,6 @@ type TaskConfig struct {
 
 	// ContainerID is the container the new task belongs to.
 	ContainerID string
-
-	// oomScoreAdj is the task's OOM score adjustment.
-	OOMScoreAdj int32
 }
 
 // NewTask creates a new task defined by cfg.
@@ -146,7 +143,6 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
 		rseqSignature:      cfg.RSeqSignature,
 		futexWaiter:        futex.NewWaiter(),
 		containerID:        cfg.ContainerID,
-		oomScoreAdj:        cfg.OOMScoreAdj,
 	}
 	t.creds.Store(cfg.Credentials)
 	t.endStopCond.L = &t.tg.signalHandlers.mu
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 268f62e9d..52849f5b3 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -254,6 +254,13 @@ type ThreadGroup struct {
 	//
 	// tty is protected by the signal mutex.
 	tty *TTY
+
+	// oomScoreAdj is the thread group's OOM score adjustment. This is
+	// currently not used but is maintained for consistency.
+	// TODO(gvisor.dev/issue/1967)
+	//
+	// oomScoreAdj is accessed using atomic memory operations.
+	oomScoreAdj int32
 }
 
 // NewThreadGroup returns a new, empty thread group in PID namespace pidns. The
-- 
cgit v1.2.3