summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp/accept.go
diff options
context:
space:
mode:
authorBhasker Hariharan <bhaskerh@google.com>2020-03-25 10:54:19 -0700
committergVisor bot <gvisor-bot@google.com>2020-03-25 10:55:22 -0700
commitd04adebaab86ac30aca463b06528fc22430598ac (patch)
treeae3a08f54d9e3254a4f20916fb373d807a1e3822 /pkg/tcpip/transport/tcp/accept.go
parentd8c4eff3f77b2a5dde34389033fe0ce4589b2e82 (diff)
Fix data-race in endpoint.Readiness
PiperOrigin-RevId: 302924789
Diffstat (limited to 'pkg/tcpip/transport/tcp/accept.go')
-rw-r--r--pkg/tcpip/transport/tcp/accept.go46
1 files changed, 28 insertions, 18 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index b4c4c8ab1..375ca21f6 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -365,21 +365,29 @@ func (l *listenContext) closeAllPendingEndpoints() {
}
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// endpoint has transitioned out of the listen state, the new endpoint is closed
-// instead.
+// endpoint has transitioned out of the listen state (acceptedChan is nil),
+// the new endpoint is closed instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
e.mu.Lock()
- state := e.EndpointState()
e.pendingAccepted.Add(1)
- defer e.pendingAccepted.Done()
- acceptedChan := e.acceptedChan
e.mu.Unlock()
+ defer e.pendingAccepted.Done()
- if state == StateListen {
- acceptedChan <- n
- e.waiterQueue.Notify(waiter.EventIn)
- } else {
- n.Close()
+ e.acceptMu.Lock()
+ for {
+ if e.acceptedChan == nil {
+ e.acceptMu.Unlock()
+ n.Close()
+ return
+ }
+ select {
+ case e.acceptedChan <- n:
+ e.acceptMu.Unlock()
+ e.waiterQueue.Notify(waiter.EventIn)
+ return
+ default:
+ e.acceptCond.Wait()
+ }
}
}
@@ -420,11 +428,13 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
}
func (e *endpoint) incSynRcvdCount() bool {
- if e.synRcvdCount >= cap(e.acceptedChan) {
- return false
+ e.acceptMu.Lock()
+ canInc := e.synRcvdCount < cap(e.acceptedChan)
+ e.acceptMu.Unlock()
+ if canInc {
+ e.synRcvdCount++
}
- e.synRcvdCount++
- return true
+ return canInc
}
func (e *endpoint) decSynRcvdCount() {
@@ -432,10 +442,10 @@ func (e *endpoint) decSynRcvdCount() {
}
func (e *endpoint) acceptQueueIsFull() bool {
- if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c {
- return true
- }
- return false
+ e.acceptMu.Lock()
+ full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
+ e.acceptMu.Unlock()
+ return full
}
// handleListenSegment is called when a listening endpoint receives a segment