diff options
author | Bhasker Hariharan <bhaskerh@google.com> | 2020-03-25 10:54:19 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-03-25 10:55:22 -0700 |
commit | d04adebaab86ac30aca463b06528fc22430598ac (patch) | |
tree | ae3a08f54d9e3254a4f20916fb373d807a1e3822 /pkg/tcpip/transport/tcp/accept.go | |
parent | d8c4eff3f77b2a5dde34389033fe0ce4589b2e82 (diff) |
Fix data-race in endpoint.Readiness
PiperOrigin-RevId: 302924789
Diffstat (limited to 'pkg/tcpip/transport/tcp/accept.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 46 |
1 files changed, 28 insertions, 18 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index b4c4c8ab1..375ca21f6 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -365,21 +365,29 @@ func (l *listenContext) closeAllPendingEndpoints() { } // deliverAccepted delivers the newly-accepted endpoint to the listener. If the -// endpoint has transitioned out of the listen state, the new endpoint is closed -// instead. +// endpoint has transitioned out of the listen state (acceptedChan is nil), +// the new endpoint is closed instead. func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.Lock() - state := e.EndpointState() e.pendingAccepted.Add(1) - defer e.pendingAccepted.Done() - acceptedChan := e.acceptedChan e.mu.Unlock() + defer e.pendingAccepted.Done() - if state == StateListen { - acceptedChan <- n - e.waiterQueue.Notify(waiter.EventIn) - } else { - n.Close() + e.acceptMu.Lock() + for { + if e.acceptedChan == nil { + e.acceptMu.Unlock() + n.Close() + return + } + select { + case e.acceptedChan <- n: + e.acceptMu.Unlock() + e.waiterQueue.Notify(waiter.EventIn) + return + default: + e.acceptCond.Wait() + } } } @@ -420,11 +428,13 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header } func (e *endpoint) incSynRcvdCount() bool { - if e.synRcvdCount >= cap(e.acceptedChan) { - return false + e.acceptMu.Lock() + canInc := e.synRcvdCount < cap(e.acceptedChan) + e.acceptMu.Unlock() + if canInc { + e.synRcvdCount++ } - e.synRcvdCount++ - return true + return canInc } func (e *endpoint) decSynRcvdCount() { @@ -432,10 +442,10 @@ func (e *endpoint) decSynRcvdCount() { } func (e *endpoint) acceptQueueIsFull() bool { - if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c { - return true - } - return false + e.acceptMu.Lock() + full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan) + e.acceptMu.Unlock() + return full } // handleListenSegment is called when a listening endpoint receives a segment |