diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 109 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 4 |
2 files changed, 43 insertions, 70 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 62b8d9de9..d41e07521 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -20,7 +20,6 @@ import ( "fmt" "hash" "io" - "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" @@ -103,14 +102,14 @@ type listenContext struct { // pendingMu protects pendingEndpoints. This should only be accessed // by the listening endpoint's worker goroutine. - // - // Lock Ordering: listenEP.workerMu -> pendingMu pendingMu sync.Mutex // pending is used to wait for all pendingEndpoints to finish when // a socket is closed. pending sync.WaitGroup // pendingEndpoints is a set of all endpoints for which a handshake is // in progress. + // + // +checklocks:pendingMu pendingEndpoints map[*endpoint]struct{} } @@ -265,7 +264,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu return nil, &tcpip.ErrConnectionAborted{} } - l.addPendingEndpoint(ep) // Propagate any inheritable options from the listening endpoint // to the newly created endpoint. @@ -275,8 +273,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu ep.mu.Unlock() ep.Close() - l.removePendingEndpoint(ep) - return nil, &tcpip.ErrConnectionAborted{} } @@ -295,10 +291,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu ep.mu.Unlock() ep.Close() - if l.listenEP != nil { - l.removePendingEndpoint(ep) - } - ep.drainClosingSegmentQueue() return nil, err @@ -336,38 +328,12 @@ func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions, return ep, nil } -func (l *listenContext) addPendingEndpoint(n *endpoint) { - l.pendingMu.Lock() - l.pendingEndpoints[n] = struct{}{} - l.pending.Add(1) - l.pendingMu.Unlock() -} - -func (l *listenContext) removePendingEndpoint(n *endpoint) { - l.pendingMu.Lock() - delete(l.pendingEndpoints, n) - l.pending.Done() - l.pendingMu.Unlock() -} - -func (l *listenContext) closeAllPendingEndpoints() { - l.pendingMu.Lock() - for n := range l.pendingEndpoints { - n.notifyProtocolGoroutine(notifyClose) - } - l.pendingMu.Unlock() - l.pending.Wait() -} - // +checklocks:h.ep.mu func (l *listenContext) cleanupFailedHandshake(h *handshake) { e := h.ep e.mu.Unlock() e.Close() e.notifyAborted() - if l.listenEP != nil { - l.removePendingEndpoint(e) - } e.drainClosingSegmentQueue() e.h = nil } @@ -378,9 +344,6 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) { // +checklocks:h.ep.mu func (l *listenContext) cleanupCompletedHandshake(h *handshake) { e := h.ep - if l.listenEP != nil { - l.removePendingEndpoint(e) - } e.isConnectNotified = true // Update the receive window scaling. We can't do it before the @@ -444,21 +407,6 @@ func (e *endpoint) notifyAborted() { e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } -func (e *endpoint) synRcvdBacklogFull() bool { - e.acceptMu.Lock() - acceptedCap := e.accepted.cap - e.acceptMu.Unlock() - // The capacity of the accepted queue would always be one greater than the - // listen backlog. But, the SYNRCVD connections count is always checked - // against the listen backlog value for Linux parity reason. - // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 - // - // We maintain an equality check here as the synRcvdCount is incremented - // and compared only from a single listener context and the capacity of - // the accepted queue can only increase by a new listen call. - return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1 -} - func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() full := e.accepted.acceptQueueIsFullLocked() @@ -500,34 +448,53 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil } - alwaysUseSynCookies := func() bool { + opts := parseSynSegmentOptions(s) + + useSynCookies, err := func() (bool, tcpip.Error) { var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) } - return bool(alwaysUseSynCookies) - }() - - opts := parseSynSegmentOptions(s) - if !alwaysUseSynCookies && !e.synRcvdBacklogFull() { - atomic.AddInt32(&e.synRcvdCount, 1) + if alwaysUseSynCookies { + return true, nil + } + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + + ctx.pendingMu.Lock() + defer ctx.pendingMu.Unlock() + // The capacity of the accepted queue would always be one greater than the + // listen backlog. But, the SYNRCVD connections count is always checked + // against the listen backlog value for Linux parity reason. + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 + if len(ctx.pendingEndpoints) == e.accepted.cap-1 { + return true, nil + } h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - atomic.AddInt32(&e.synRcvdCount, -1) - return err + return false, err } + ctx.pendingEndpoints[h.ep] = struct{}{} + ctx.pending.Add(1) + go func() { + defer func() { + ctx.pendingMu.Lock() + defer ctx.pendingMu.Unlock() + delete(ctx.pendingEndpoints, h.ep) + ctx.pending.Done() + }() + // Note that startHandshake returns a locked endpoint. The force call // here just makes it so. if err := h.complete(); err != nil { // +checklocksforce e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() ctx.cleanupFailedHandshake(h) - atomic.AddInt32(&e.synRcvdCount, -1) return } ctx.cleanupCompletedHandshake(h) @@ -558,7 +525,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } e.accepted.endpoints.PushBack(h.ep) - atomic.AddInt32(&e.synRcvdCount, -1) return true } }() @@ -570,6 +536,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } }() + return false, nil + }() + if err != nil { + return err + } + if !useSynCookies { return nil } @@ -780,7 +752,12 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.setEndpointState(StateClose) // Close any endpoints in SYN-RCVD state. - ctx.closeAllPendingEndpoints() + ctx.pendingMu.Lock() + for n := range ctx.pendingEndpoints { + n.notifyProtocolGoroutine(notifyClose) + } + ctx.pendingMu.Unlock() + ctx.pending.Wait() // Do cleanup if needed. e.completeWorkerLocked() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 7c8a11cfb..9c5b1b016 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -508,10 +508,6 @@ type endpoint struct { // and dropped when it is. segmentQueue segmentQueue `state:"wait"` - // synRcvdCount is the number of connections for this endpoint that are - // in SYN-RCVD state; this is only accessed atomically. - synRcvdCount int32 - // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. userMSS uint16 |