diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/accept.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 143 |
1 files changed, 47 insertions, 96 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 7534aa89f..6b3238d6b 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -228,15 +228,11 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i return n } -// startHandshake creates a new endpoint in connecting state and then sends -// the SYN-ACK for the TCP 3-way handshake. It returns the state of the -// handshake in progress, which includes the new endpoint in the SYN-RCVD -// state. +// createEndpointAndPerformHandshake creates a new endpoint in connected state +// and then performs the TCP 3-way handshake. // -// On success, a handshake h is returned with h.ep.mu held. -// -// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) { +// The new endpoint is returned with e.mu held. +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) @@ -251,8 +247,10 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // listenEP is nil when listenContext is used by tcp.Forwarder. deferAccept := time.Duration(0) if l.listenEP != nil { + l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { + l.listenEP.mu.Unlock() // Ensure we release any registrations done by the newly // created endpoint. ep.mu.Unlock() @@ -270,12 +268,16 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q ep.mu.Unlock() ep.Close() - l.removePendingEndpoint(ep) + if l.listenEP != nil { + l.removePendingEndpoint(ep) + l.listenEP.mu.Unlock() + } return nil, tcpip.ErrConnectionAborted } deferAccept = l.listenEP.deferAccept + l.listenEP.mu.Unlock() } // Register new endpoint so that packets are routed to it. @@ -294,33 +296,28 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q ep.isRegistered = true - // Initialize and start the handshake. - h := ep.newPassiveHandshake(isn, irs, opts, deferAccept) - if err := h.start(); err != nil { - l.cleanupFailedHandshake(h) - return nil, err - } - return h, nil -} + // Perform the 3-way handshake. + h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) + if err := h.execute(); err != nil { + ep.mu.Unlock() + ep.Close() + ep.notifyAborted() -// performHandshake performs a TCP 3-way handshake. On success, the new -// established endpoint is returned with e.mu held. -// -// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { - h, err := l.startHandshake(s, opts, queue, owner) - if err != nil { - return nil, err - } - ep := h.ep + if l.listenEP != nil { + l.removePendingEndpoint(ep) + } + + ep.drainClosingSegmentQueue() - if err := h.complete(); err != nil { - ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() - ep.stats.FailedConnectionAttempts.Increment() - l.cleanupFailedHandshake(h) return nil, err } - l.cleanupCompletedHandshake(h) + ep.isConnectNotified = true + + // Update the receive window scaling. We can't do it before the + // handshake because it's possible that the peer doesn't support window + // scaling. + ep.rcv.rcvWndScale = h.effectiveRcvWndScale() + return ep, nil } @@ -347,39 +344,6 @@ func (l *listenContext) closeAllPendingEndpoints() { l.pending.Wait() } -// Precondition: h.ep.mu must be held. -func (l *listenContext) cleanupFailedHandshake(h *handshake) { - e := h.ep - e.mu.Unlock() - e.Close() - e.notifyAborted() - if l.listenEP != nil { - l.removePendingEndpoint(e) - } - e.drainClosingSegmentQueue() - e.h = nil -} - -// cleanupCompletedHandshake transfers any state from the completed handshake to -// the new endpoint. -// -// Precondition: h.ep.mu must be held. -func (l *listenContext) cleanupCompletedHandshake(h *handshake) { - e := h.ep - if l.listenEP != nil { - l.removePendingEndpoint(e) - } - e.isConnectNotified = true - - // Update the receive window scaling. We can't do it before the - // handshake because it's possible that the peer doesn't support window - // scaling. - e.rcv.rcvWndScale = e.h.effectiveRcvWndScale() - - // Clean up handshake state stored in the endpoint so that it can be GCed. - e.h = nil -} - // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state (acceptedChan is nil), // the new endpoint is closed instead. @@ -459,40 +423,23 @@ func (e *endpoint) notifyAborted() { // // A limited number of these goroutines are allowed before TCP starts using SYN // cookies to accept connections. -// -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. -func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { + defer ctx.synRcvdCount.dec() defer s.decRef() - h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) + n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - e.synRcvdCount-- - return err + e.decSynRcvdCount() + return } + ctx.removePendingEndpoint(n) + e.decSynRcvdCount() + n.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - go func() { - defer ctx.synRcvdCount.dec() - if err := h.complete(); err != nil { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() - ctx.cleanupFailedHandshake(h) - e.mu.Lock() - e.synRcvdCount-- - e.mu.Unlock() - return - } - ctx.cleanupCompletedHandshake(h) - e.mu.Lock() - e.synRcvdCount-- - e.mu.Unlock() - h.ep.startAcceptedLoop() - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(h.ep) - }() // S/R-SAFE: synRcvdCount is the barrier. - - return nil + e.deliverAccepted(n) } func (e *endpoint) incSynRcvdCount() bool { @@ -505,6 +452,12 @@ func (e *endpoint) incSynRcvdCount() bool { return canInc } +func (e *endpoint) decSynRcvdCount() { + e.mu.Lock() + e.synRcvdCount-- + e.mu.Unlock() +} + func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan) @@ -514,8 +467,6 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. -// -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.rcvListMu.Lock() rcvClosed := e.rcvClosed @@ -540,7 +491,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // backlog. if !e.acceptQueueIsFull() && e.incSynRcvdCount() { s.incRef() - _ = e.handleSynSegment(ctx, s, &opts) + go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. return } ctx.synRcvdCount.dec() @@ -735,7 +686,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // to the endpoint. e.setEndpointState(StateClose) - // Close any endpoints in SYN-RCVD state. + // close any endpoints in SYN-RCVD state. ctx.closeAllPendingEndpoints() // Do cleanup if needed. |