summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go65
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go38
2 files changed, 98 insertions, 5 deletions
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index 232a276dc..c134931cd 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -16,6 +16,7 @@
package semaphore
import (
+ "fmt"
"sync"
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
@@ -75,7 +76,10 @@ type Set struct {
perms fs.FilePermissions
opTime ktime.Time
changeTime ktime.Time
- sems []sem
+
+ // sems holds all semaphores in the set. The slice itself is immutable after
+ // it's been set, however each 'sem' object in the slice requires 'mu' lock.
+ sems []sem
// dead is set to true when the set is removed and can't be reached anymore.
// All waiters must wake up and fail when set is dead.
@@ -136,7 +140,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
}
// Validate parameters.
- if nsems > int32(set.size()) {
+ if nsems > int32(set.Size()) {
return nil, syserror.EINVAL
}
if create && exclusive {
@@ -244,19 +248,20 @@ func (r *Registry) findByKey(key int32) *Set {
func (r *Registry) totalSems() int {
totalSems := 0
for _, v := range r.semaphores {
- totalSems += v.size()
+ totalSems += v.Size()
}
return totalSems
}
func (s *Set) findSem(num int32) *sem {
- if num < 0 || int(num) >= s.size() {
+ if num < 0 || int(num) >= s.Size() {
return nil
}
return &s.sems[num]
}
-func (s *Set) size() int {
+// Size returns the number of semaphores in the set. Size is immutable.
+func (s *Set) Size() int {
return len(s.sems)
}
@@ -303,6 +308,39 @@ func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Cred
return nil
}
+// SetValAll overrides all semaphores values, waking up waiters as needed.
+//
+// 'len(vals)' must be equal to 's.Size()'.
+func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credentials) error {
+ if len(vals) != s.Size() {
+ panic(fmt.Sprintf("vals length (%d) different that Set.Size() (%d)", len(vals), s.Size()))
+ }
+
+ for _, val := range vals {
+ if val < 0 || val > valueMax {
+ return syserror.ERANGE
+ }
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have alter permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Write: true}) {
+ return syserror.EACCES
+ }
+
+ for i, val := range vals {
+ sem := &s.sems[i]
+
+ // TODO: Clear undo entries in all processes
+ sem.value = int16(val)
+ sem.wakeWaiters()
+ }
+ s.changeTime = ktime.NowFromContext(ctx)
+ return nil
+}
+
// GetVal returns a semaphore value.
func (s *Set) GetVal(num int32, creds *auth.Credentials) (int16, error) {
s.mu.Lock()
@@ -320,6 +358,23 @@ func (s *Set) GetVal(num int32, creds *auth.Credentials) (int16, error) {
return sem.value, nil
}
+// GetValAll returns value for all semaphores.
+func (s *Set) GetValAll(creds *auth.Credentials) ([]uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have read permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ return nil, syserror.EACCES
+ }
+
+ vals := make([]uint16, s.Size())
+ for i, sem := range s.sems {
+ vals[i] = uint16(sem.value)
+ }
+ return vals, nil
+}
+
// ExecuteOps attempts to execute a list of operations to the set. It only
// succeeds when all operations can be applied. No changes are made if it fails.
//
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index 4ed52c4a7..6775725ca 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -22,6 +22,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
"gvisor.googlesource.com/gvisor/pkg/syserror"
)
@@ -97,10 +98,18 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
return 0, nil, setVal(t, id, num, int16(val))
+ case linux.SETALL:
+ array := args[3].Pointer()
+ return 0, nil, setValAll(t, id, array)
+
case linux.GETVAL:
v, err := getVal(t, id, num)
return uintptr(v), nil, err
+ case linux.GETALL:
+ array := args[3].Pointer()
+ return 0, nil, getValAll(t, id, array)
+
case linux.IPC_RMID:
return 0, nil, remove(t, id)
@@ -155,6 +164,20 @@ func setVal(t *kernel.Task, id int32, num int32, val int16) error {
return set.SetVal(t, num, val, creds)
}
+func setValAll(t *kernel.Task, id int32, array usermem.Addr) error {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return syserror.EINVAL
+ }
+ vals := make([]uint16, set.Size())
+ if _, err := t.CopyIn(array, vals); err != nil {
+ return err
+ }
+ creds := auth.CredentialsFromContext(t)
+ return set.SetValAll(t, vals, creds)
+}
+
func getVal(t *kernel.Task, id int32, num int32) (int16, error) {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
@@ -164,3 +187,18 @@ func getVal(t *kernel.Task, id int32, num int32) (int16, error) {
creds := auth.CredentialsFromContext(t)
return set.GetVal(num, creds)
}
+
+func getValAll(t *kernel.Task, id int32, array usermem.Addr) error {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ vals, err := set.GetValAll(creds)
+ if err != nil {
+ return err
+ }
+ _, err = t.CopyOut(array, vals)
+ return err
+}