From 3ac00fe9c3396f07a1416ff7fc855f6f9a3c4304 Mon Sep 17 00:00:00 2001 From: Jing Chen Date: Fri, 6 Nov 2020 18:36:03 -0800 Subject: Implement command GETNCNT for semctl. PiperOrigin-RevId: 341154192 --- pkg/sentry/kernel/semaphore/semaphore.go | 35 +++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) (limited to 'pkg/sentry/kernel/semaphore') diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index 310762936..b99c0bffa 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -103,6 +103,7 @@ type waiter struct { waiterEntry // value represents how much resource the waiter needs to wake up. + // The value is either 0 or negative. value int16 ch chan struct{} } @@ -423,8 +424,7 @@ func (s *Set) GetPID(num int32, creds *auth.Credentials) (int32, error) { return sem.pid, nil } -// GetZeroWaiters returns number of waiters waiting for the sem to go to zero. -func (s *Set) GetZeroWaiters(num int32, creds *auth.Credentials) (uint16, error) { +func (s *Set) countWaiters(num int32, creds *auth.Credentials, pred func(w *waiter) bool) (uint16, error) { s.mu.Lock() defer s.mu.Unlock() @@ -437,13 +437,27 @@ func (s *Set) GetZeroWaiters(num int32, creds *auth.Credentials) (uint16, error) if sem == nil { return 0, syserror.ERANGE } - var semzcnt uint16 + var cnt uint16 for w := sem.waiters.Front(); w != nil; w = w.Next() { - if w.value == 0 { - semzcnt++ + if pred(w) { + cnt++ } } - return semzcnt, nil + return cnt, nil +} + +// CountZeroWaiters returns number of waiters waiting for the sem's value to increase. +func (s *Set) CountZeroWaiters(num int32, creds *auth.Credentials) (uint16, error) { + return s.countWaiters(num, creds, func(w *waiter) bool { + return w.value == 0 + }) +} + +// CountNegativeWaiters returns number of waiters waiting for the sem to go to zero. +func (s *Set) CountNegativeWaiters(num int32, creds *auth.Credentials) (uint16, error) { + return s.countWaiters(num, creds, func(w *waiter) bool { + return w.value < 0 + }) } // ExecuteOps attempts to execute a list of operations to the set. It only @@ -598,11 +612,18 @@ func (s *Set) destroy() { } } +func abs(val int16) int16 { + if val < 0 { + return -val + } + return val +} + // wakeWaiters goes over all waiters and checks which of them can be notified. func (s *sem) wakeWaiters() { // Note that this will release all waiters waiting for 0 too. for w := s.waiters.Front(); w != nil; { - if s.value < w.value { + if s.value < abs(w.value) { // Still blocked, skip it. w = w.Next() continue -- cgit v1.2.3