diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/sentry/kernel/semaphore/semaphore.go | 65 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/sys_sem.go | 38 |
2 files changed, 98 insertions, 5 deletions
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index 232a276dc..c134931cd 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -16,6 +16,7 @@ package semaphore import ( + "fmt" "sync" "gvisor.googlesource.com/gvisor/pkg/abi/linux" @@ -75,7 +76,10 @@ type Set struct { perms fs.FilePermissions opTime ktime.Time changeTime ktime.Time - sems []sem + + // sems holds all semaphores in the set. The slice itself is immutable after + // it's been set, however each 'sem' object in the slice requires 'mu' lock. + sems []sem // dead is set to true when the set is removed and can't be reached anymore. // All waiters must wake up and fail when set is dead. @@ -136,7 +140,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu } // Validate parameters. - if nsems > int32(set.size()) { + if nsems > int32(set.Size()) { return nil, syserror.EINVAL } if create && exclusive { @@ -244,19 +248,20 @@ func (r *Registry) findByKey(key int32) *Set { func (r *Registry) totalSems() int { totalSems := 0 for _, v := range r.semaphores { - totalSems += v.size() + totalSems += v.Size() } return totalSems } func (s *Set) findSem(num int32) *sem { - if num < 0 || int(num) >= s.size() { + if num < 0 || int(num) >= s.Size() { return nil } return &s.sems[num] } -func (s *Set) size() int { +// Size returns the number of semaphores in the set. Size is immutable. +func (s *Set) Size() int { return len(s.sems) } @@ -303,6 +308,39 @@ func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Cred return nil } +// SetValAll overrides all semaphores values, waking up waiters as needed. +// +// 'len(vals)' must be equal to 's.Size()'. +func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credentials) error { + if len(vals) != s.Size() { + panic(fmt.Sprintf("vals length (%d) different that Set.Size() (%d)", len(vals), s.Size())) + } + + for _, val := range vals { + if val < 0 || val > valueMax { + return syserror.ERANGE + } + } + + s.mu.Lock() + defer s.mu.Unlock() + + // "The calling process must have alter permission on the semaphore set." + if !s.checkPerms(creds, fs.PermMask{Write: true}) { + return syserror.EACCES + } + + for i, val := range vals { + sem := &s.sems[i] + + // TODO: Clear undo entries in all processes + sem.value = int16(val) + sem.wakeWaiters() + } + s.changeTime = ktime.NowFromContext(ctx) + return nil +} + // GetVal returns a semaphore value. func (s *Set) GetVal(num int32, creds *auth.Credentials) (int16, error) { s.mu.Lock() @@ -320,6 +358,23 @@ func (s *Set) GetVal(num int32, creds *auth.Credentials) (int16, error) { return sem.value, nil } +// GetValAll returns value for all semaphores. +func (s *Set) GetValAll(creds *auth.Credentials) ([]uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // "The calling process must have read permission on the semaphore set." + if !s.checkPerms(creds, fs.PermMask{Read: true}) { + return nil, syserror.EACCES + } + + vals := make([]uint16, s.Size()) + for i, sem := range s.sems { + vals[i] = uint16(sem.value) + } + return vals, nil +} + // ExecuteOps attempts to execute a list of operations to the set. It only // succeeds when all operations can be applied. No changes are made if it fails. // diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index 4ed52c4a7..6775725ca 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -22,6 +22,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserror" ) @@ -97,10 +98,18 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } return 0, nil, setVal(t, id, num, int16(val)) + case linux.SETALL: + array := args[3].Pointer() + return 0, nil, setValAll(t, id, array) + case linux.GETVAL: v, err := getVal(t, id, num) return uintptr(v), nil, err + case linux.GETALL: + array := args[3].Pointer() + return 0, nil, getValAll(t, id, array) + case linux.IPC_RMID: return 0, nil, remove(t, id) @@ -155,6 +164,20 @@ func setVal(t *kernel.Task, id int32, num int32, val int16) error { return set.SetVal(t, num, val, creds) } +func setValAll(t *kernel.Task, id int32, array usermem.Addr) error { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return syserror.EINVAL + } + vals := make([]uint16, set.Size()) + if _, err := t.CopyIn(array, vals); err != nil { + return err + } + creds := auth.CredentialsFromContext(t) + return set.SetValAll(t, vals, creds) +} + func getVal(t *kernel.Task, id int32, num int32) (int16, error) { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) @@ -164,3 +187,18 @@ func getVal(t *kernel.Task, id int32, num int32) (int16, error) { creds := auth.CredentialsFromContext(t) return set.GetVal(num, creds) } + +func getValAll(t *kernel.Task, id int32, array usermem.Addr) error { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + vals, err := set.GetValAll(creds) + if err != nil { + return err + } + _, err = t.CopyOut(array, vals) + return err +} |