diff options
-rw-r--r-- | pkg/sentry/kernel/shm/shm.go | 24 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/sys_shm.go | 8 |
2 files changed, 19 insertions, 13 deletions
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index f760f5f76..4343dee13 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -51,6 +51,12 @@ import ( "gvisor.googlesource.com/gvisor/pkg/syserror" ) +// Key represents a shm segment key. Analogous to a file name. +type Key int32 + +// ID represents the opaque handle for a shm segment. Analogous to an fd. +type ID int32 + // Registry tracks all shared memory segments in an IPC namespace. The registry // provides the mechanisms for creating and finding segments, and reporting // global shm parameters. @@ -63,33 +69,33 @@ type Registry struct { mu sync.Mutex `state:"nosave"` // shms maps segment ids to segments. Protected by mu. - shms map[int32]*Shm + shms map[ID]*Shm // Sum of the sizes of all existing segments rounded up to page size, in // units of page size. Protected by mu. totalPages uint64 // lastIDUsed is protected by mu. - lastIDUsed int32 + lastIDUsed ID } // NewRegistry creates a new shm registry. func NewRegistry(userNS *auth.UserNamespace) *Registry { return &Registry{ userNS: userNS, - shms: make(map[int32]*Shm), + shms: make(map[ID]*Shm), } } // FindByID looks up a segment given an ID. -func (r *Registry) FindByID(id int32) *Shm { +func (r *Registry) FindByID(id ID) *Shm { r.mu.Lock() defer r.mu.Unlock() return r.shms[id] } // Precondition: Caller must hold r.mu. -func (r *Registry) findByKey(key int32) *Shm { +func (r *Registry) findByKey(key Key) *Shm { for _, v := range r.shms { if v.key == key { return v @@ -100,7 +106,7 @@ func (r *Registry) findByKey(key int32) *Shm { // FindOrCreate looks up or creates a segment in the registry. It's functionally // analogous to open(2). -func (r *Registry) FindOrCreate(ctx context.Context, pid, key int32, size uint64, mode linux.FileMode, private, create, exclusive bool) (*Shm, error) { +func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size uint64, mode linux.FileMode, private, create, exclusive bool) (*Shm, error) { if (create || private) && (size < linux.SHMMIN || size > linux.SHMMAX) { // "A new segment was to be created and size is less than SHMMIN or // greater than SHMMAX." - man shmget(2) @@ -178,7 +184,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid, key int32, size uint64 } // newShm creates a new segment in the registry. -func (r *Registry) newShm(ctx context.Context, pid, key int32, creator fs.FileOwner, perms fs.FilePermissions, size uint64) (*Shm, error) { +func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.FileOwner, perms fs.FilePermissions, size uint64) (*Shm, error) { p := platform.FromContext(ctx) if p == nil { panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, platform.CtxPlatform)) @@ -289,7 +295,7 @@ type Shm struct { registry *Registry // ID is the kernel identifier for this segment. Immutable. - ID int32 + ID ID // creator is the user that created the segment. Immutable. creator fs.FileOwner @@ -309,7 +315,7 @@ type Shm struct { fr platform.FileRange // key is the public identifier for this segment. - key int32 + key Key // mu protects all fields below. mu sync.Mutex `state:"nosave"` diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go index 8753c2e58..a0d3a73c5 100644 --- a/pkg/sentry/syscalls/linux/sys_shm.go +++ b/pkg/sentry/syscalls/linux/sys_shm.go @@ -24,7 +24,7 @@ import ( // Shmget implements shmget(2). func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - key := args[0].Int() + key := shm.Key(args[0].Int()) size := uint64(args[1].SizeT()) flag := args[2].Int() @@ -43,7 +43,7 @@ func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // findSegment retrives a shm segment by the given id. -func findSegment(t *kernel.Task, id int32) (*shm.Shm, error) { +func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) { r := t.IPCNamespace().ShmRegistry() segment := r.FindByID(id) if segment == nil { @@ -55,7 +55,7 @@ func findSegment(t *kernel.Task, id int32) (*shm.Shm, error) { // Shmat implements shmat(2). func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - id := args[0].Int() + id := shm.ID(args[0].Int()) addr := args[1].Pointer() flag := args[2].Int() @@ -86,7 +86,7 @@ func Shmdt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Shmctl implements shmctl(2). func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - id := args[0].Int() + id := shm.ID(args[0].Int()) cmd := args[1].Int() buf := args[2].Pointer() |