summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorRahat Mahmood <rahat@google.com>2018-12-10 12:47:20 -0800
committerShentubot <shentubot@google.com>2018-12-10 12:48:02 -0800
commitfc297702511edef4760c4f7a1d89cc6f02347d50 (patch)
tree8bdcb8a8086af0de4db2525b1ea1327449e89bf1 /pkg
parent99d595869332f817de8f570fae184658c513a43c (diff)
Add type safety to shm ids and keys.
PiperOrigin-RevId: 224864380 Change-Id: I49542279ad56bf15ba462d3de1ef2b157b31830a
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/kernel/shm/shm.go24
-rw-r--r--pkg/sentry/syscalls/linux/sys_shm.go8
2 files changed, 19 insertions, 13 deletions
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index f760f5f76..4343dee13 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -51,6 +51,12 @@ import (
"gvisor.googlesource.com/gvisor/pkg/syserror"
)
+// Key represents a shm segment key. Analogous to a file name.
+type Key int32
+
+// ID represents the opaque handle for a shm segment. Analogous to an fd.
+type ID int32
+
// Registry tracks all shared memory segments in an IPC namespace. The registry
// provides the mechanisms for creating and finding segments, and reporting
// global shm parameters.
@@ -63,33 +69,33 @@ type Registry struct {
mu sync.Mutex `state:"nosave"`
// shms maps segment ids to segments. Protected by mu.
- shms map[int32]*Shm
+ shms map[ID]*Shm
// Sum of the sizes of all existing segments rounded up to page size, in
// units of page size. Protected by mu.
totalPages uint64
// lastIDUsed is protected by mu.
- lastIDUsed int32
+ lastIDUsed ID
}
// NewRegistry creates a new shm registry.
func NewRegistry(userNS *auth.UserNamespace) *Registry {
return &Registry{
userNS: userNS,
- shms: make(map[int32]*Shm),
+ shms: make(map[ID]*Shm),
}
}
// FindByID looks up a segment given an ID.
-func (r *Registry) FindByID(id int32) *Shm {
+func (r *Registry) FindByID(id ID) *Shm {
r.mu.Lock()
defer r.mu.Unlock()
return r.shms[id]
}
// Precondition: Caller must hold r.mu.
-func (r *Registry) findByKey(key int32) *Shm {
+func (r *Registry) findByKey(key Key) *Shm {
for _, v := range r.shms {
if v.key == key {
return v
@@ -100,7 +106,7 @@ func (r *Registry) findByKey(key int32) *Shm {
// FindOrCreate looks up or creates a segment in the registry. It's functionally
// analogous to open(2).
-func (r *Registry) FindOrCreate(ctx context.Context, pid, key int32, size uint64, mode linux.FileMode, private, create, exclusive bool) (*Shm, error) {
+func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size uint64, mode linux.FileMode, private, create, exclusive bool) (*Shm, error) {
if (create || private) && (size < linux.SHMMIN || size > linux.SHMMAX) {
// "A new segment was to be created and size is less than SHMMIN or
// greater than SHMMAX." - man shmget(2)
@@ -178,7 +184,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid, key int32, size uint64
}
// newShm creates a new segment in the registry.
-func (r *Registry) newShm(ctx context.Context, pid, key int32, creator fs.FileOwner, perms fs.FilePermissions, size uint64) (*Shm, error) {
+func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.FileOwner, perms fs.FilePermissions, size uint64) (*Shm, error) {
p := platform.FromContext(ctx)
if p == nil {
panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, platform.CtxPlatform))
@@ -289,7 +295,7 @@ type Shm struct {
registry *Registry
// ID is the kernel identifier for this segment. Immutable.
- ID int32
+ ID ID
// creator is the user that created the segment. Immutable.
creator fs.FileOwner
@@ -309,7 +315,7 @@ type Shm struct {
fr platform.FileRange
// key is the public identifier for this segment.
- key int32
+ key Key
// mu protects all fields below.
mu sync.Mutex `state:"nosave"`
diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go
index 8753c2e58..a0d3a73c5 100644
--- a/pkg/sentry/syscalls/linux/sys_shm.go
+++ b/pkg/sentry/syscalls/linux/sys_shm.go
@@ -24,7 +24,7 @@ import (
// Shmget implements shmget(2).
func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- key := args[0].Int()
+ key := shm.Key(args[0].Int())
size := uint64(args[1].SizeT())
flag := args[2].Int()
@@ -43,7 +43,7 @@ func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
// findSegment retrives a shm segment by the given id.
-func findSegment(t *kernel.Task, id int32) (*shm.Shm, error) {
+func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) {
r := t.IPCNamespace().ShmRegistry()
segment := r.FindByID(id)
if segment == nil {
@@ -55,7 +55,7 @@ func findSegment(t *kernel.Task, id int32) (*shm.Shm, error) {
// Shmat implements shmat(2).
func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- id := args[0].Int()
+ id := shm.ID(args[0].Int())
addr := args[1].Pointer()
flag := args[2].Int()
@@ -86,7 +86,7 @@ func Shmdt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Shmctl implements shmctl(2).
func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- id := args[0].Int()
+ id := shm.ID(args[0].Int())
cmd := args[1].Int()
buf := args[2].Pointer()