diff options
Diffstat (limited to 'pkg/sentry')
-rw-r--r-- | pkg/sentry/kernel/auth/context.go | 13 | ||||
-rw-r--r-- | pkg/sentry/kernel/mq/mq.go | 2 | ||||
-rw-r--r-- | pkg/sentry/kernel/shm/shm.go | 4 | ||||
-rw-r--r-- | pkg/sentry/kernel/task_context.go | 2 |
4 files changed, 17 insertions, 4 deletions
diff --git a/pkg/sentry/kernel/auth/context.go b/pkg/sentry/kernel/auth/context.go index c08d47787..2039a96ad 100644 --- a/pkg/sentry/kernel/auth/context.go +++ b/pkg/sentry/kernel/auth/context.go @@ -24,6 +24,10 @@ type contextID int const ( // CtxCredentials is a Context.Value key for Credentials. CtxCredentials contextID = iota + + // CtxThreadGroupID is the current thread group ID when a context represents + // a task context. The value is represented as an int32. + CtxThreadGroupID contextID = iota ) // CredentialsFromContext returns a copy of the Credentials used by ctx, or a @@ -35,6 +39,15 @@ func CredentialsFromContext(ctx context.Context) *Credentials { return NewAnonymousCredentials() } +// ThreadGroupIDFromContext returns the current thread group ID when ctx +// represents a task context. +func ThreadGroupIDFromContext(ctx context.Context) (tgid int32, ok bool) { + if tgid := ctx.Value(CtxThreadGroupID); tgid != nil { + return tgid.(int32), true + } + return 0, false +} + // ContextWithCredentials returns a copy of ctx carrying creds. func ContextWithCredentials(ctx context.Context, creds *Credentials) context.Context { return &authContext{ctx, creds} diff --git a/pkg/sentry/kernel/mq/mq.go b/pkg/sentry/kernel/mq/mq.go index 07482decf..7515a2772 100644 --- a/pkg/sentry/kernel/mq/mq.go +++ b/pkg/sentry/kernel/mq/mq.go @@ -399,7 +399,7 @@ func (q *Queue) Flush(ctx context.Context) { q.mu.Lock() defer q.mu.Unlock() - pid, ok := context.ThreadGroupIDFromContext(ctx) + pid, ok := auth.ThreadGroupIDFromContext(ctx) if ok { if q.subscriber != nil && pid == q.subscriber.pid { q.subscriber = nil diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index ab938fa3c..bb9a129ab 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -444,7 +444,7 @@ func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ hostarch. s.mu.Lock() defer s.mu.Unlock() s.attachTime = ktime.NowFromContext(ctx) - if pid, ok := context.ThreadGroupIDFromContext(ctx); ok { + if pid, ok := auth.ThreadGroupIDFromContext(ctx); ok { s.lastAttachDetachPID = pid } else { // AddMapping is called during a syscall, so ctx should always be a task @@ -468,7 +468,7 @@ func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ hostar // If called from a non-task context we also won't have a threadgroup // id. Silently skip updating the lastAttachDetachPid in that case. - if pid, ok := context.ThreadGroupIDFromContext(ctx); ok { + if pid, ok := auth.ThreadGroupIDFromContext(ctx); ok { s.lastAttachDetachPID = pid } else { log.Debugf("Couldn't obtain pid when removing mapping to %s, not updating the last detach pid.", s.debugLocked()) diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index cb9bcd7c0..ce38d9342 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -86,7 +86,7 @@ func (t *Task) contextValue(key interface{}, isTaskGoroutine bool) interface{} { return t case auth.CtxCredentials: return t.creds.Load() - case context.CtxThreadGroupID: + case auth.CtxThreadGroupID: return int32(t.tg.ID()) case fs.CtxRoot: if !isTaskGoroutine { |