From 6410387ff9b4f0dbe88325ea0e30776f5f3efd5d Mon Sep 17 00:00:00 2001 From: Michael Pratt Date: Mon, 6 Jan 2020 09:27:35 -0800 Subject: Cleanup Shm reference handling Currently, shm.Registry.FindByID will return Shm instances without taking an additional reference on them, making it possible for them to disappear. More explicitly handle references. All callers hold a reference for the duration that they hold the instance. Registry.shms may transitively hold Shms with no references, so it must TryIncRef to determine if they are still valid. PiperOrigin-RevId: 288314529 --- pkg/sentry/syscalls/linux/sys_shm.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'pkg/sentry/syscalls') diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go index d57ffb3a1..4a8bc24a2 100644 --- a/pkg/sentry/syscalls/linux/sys_shm.go +++ b/pkg/sentry/syscalls/linux/sys_shm.go @@ -39,10 +39,13 @@ func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, err } + defer segment.DecRef() return uintptr(segment.ID), nil, nil } // findSegment retrives a shm segment by the given id. +// +// findSegment returns a reference on Shm. func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) { r := t.IPCNamespace().ShmRegistry() segment := r.FindByID(id) @@ -63,6 +66,7 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, syserror.EINVAL } + defer segment.DecRef() opts, err := segment.ConfigureAttach(t, addr, shm.AttachOpts{ Execute: flag&linux.SHM_EXEC == linux.SHM_EXEC, @@ -72,7 +76,6 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - defer segment.DecRef() addr, err = t.MemoryManager().MMap(t, opts) return uintptr(addr), nil, err } @@ -105,6 +108,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, syserror.EINVAL } + defer segment.DecRef() stat, err := segment.IPCStat(t) if err == nil { @@ -128,6 +132,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if err != nil { return 0, nil, syserror.EINVAL } + defer segment.DecRef() switch cmd { case linux.IPC_SET: -- cgit v1.2.3