diff options
Diffstat (limited to 'pkg/sentry/socket')
-rwxr-xr-x | pkg/sentry/socket/netstack/netstack.go | 24 | ||||
-rwxr-xr-x | pkg/sentry/socket/netstack/netstack_state_autogen.go | 2 |
2 files changed, 22 insertions, 4 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 13a9a60b4..a2e1da02f 100755 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -29,6 +29,7 @@ import ( "io" "math" "reflect" + "sync/atomic" "syscall" "time" @@ -264,6 +265,12 @@ type SocketOperations struct { skType linux.SockType protocol int + // readViewHasData is 1 iff readView has data to be read, 0 otherwise. + // Must be accessed using atomic operations. It must only be written + // with readMu held but can be read without holding readMu. The latter + // is required to avoid deadlocks in epoll Readiness checks. + readViewHasData uint32 + // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` // readView contains the remaining payload from the last packet. @@ -410,21 +417,24 @@ func (s *SocketOperations) isPacketBased() bool { // fetchReadView updates the readView field of the socket if it's currently // empty. It assumes that the socket is locked. +// +// Precondition: s.readMu must be held. func (s *SocketOperations) fetchReadView() *syserr.Error { if len(s.readView) > 0 { return nil } - s.readView = nil s.sender = tcpip.FullAddress{} v, cms, err := s.Endpoint.Read(&s.sender) if err != nil { + atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) } s.readView = v s.readCM = cms + atomic.StoreUint32(&s.readViewHasData, 1) return nil } @@ -623,11 +633,9 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Check our cached value iff the caller asked for readability and the // endpoint itself is currently not readable. if (mask & ^r & waiter.EventIn) != 0 { - s.readMu.Lock() - if len(s.readView) > 0 { + if atomic.LoadUint32(&s.readViewHasData) == 1 { r |= waiter.EventIn } - s.readMu.Unlock() } return r @@ -2334,6 +2342,10 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq } copied += n s.readView.TrimFront(n) + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } + dst = dst.DropFirst(n) if e != nil { err = syserr.FromError(e) @@ -2456,6 +2468,10 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readView.TrimFront(int(n)) } + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } + var flags int if msgLen > int(n) { flags |= linux.MSG_TRUNC diff --git a/pkg/sentry/socket/netstack/netstack_state_autogen.go b/pkg/sentry/socket/netstack/netstack_state_autogen.go index 608f23f63..9be6fe242 100755 --- a/pkg/sentry/socket/netstack/netstack_state_autogen.go +++ b/pkg/sentry/socket/netstack/netstack_state_autogen.go @@ -15,6 +15,7 @@ func (x *SocketOperations) save(m state.Map) { m.Save("Endpoint", &x.Endpoint) m.Save("skType", &x.skType) m.Save("protocol", &x.protocol) + m.Save("readViewHasData", &x.readViewHasData) m.Save("readView", &x.readView) m.Save("readCM", &x.readCM) m.Save("sender", &x.sender) @@ -32,6 +33,7 @@ func (x *SocketOperations) load(m state.Map) { m.Load("Endpoint", &x.Endpoint) m.Load("skType", &x.skType) m.Load("protocol", &x.protocol) + m.Load("readViewHasData", &x.readViewHasData) m.Load("readView", &x.readView) m.Load("readCM", &x.readCM) m.Load("sender", &x.sender) |