diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/accept.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 89 |
1 files changed, 71 insertions, 18 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 52fd1bfa3..e9c5099ea 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -96,6 +96,17 @@ type listenContext struct { hasher hash.Hash v6only bool netProto tcpip.NetworkProtocolNumber + // pendingMu protects pendingEndpoints. This should only be accessed + // by the listening endpoint's worker goroutine. + // + // Lock Ordering: listenEP.workerMu -> pendingMu + pendingMu sync.Mutex + // pending is used to wait for all pendingEndpoints to finish when + // a socket is closed. + pending sync.WaitGroup + // pendingEndpoints is a map of all endpoints for which a handshake is + // in progress. + pendingEndpoints map[stack.TransportEndpointID]*endpoint } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. @@ -133,14 +144,15 @@ func decSynRcvdCount() { } // newListenContext creates a new listen context. -func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { +func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ - stack: stack, - rcvWnd: rcvWnd, - hasher: sha1.New(), - v6only: v6only, - netProto: netProto, - listenEP: listenEP, + stack: stk, + rcvWnd: rcvWnd, + hasher: sha1.New(), + v6only: v6only, + netProto: netProto, + listenEP: listenEP, + pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } rand.Read(l.nonce[0][:]) @@ -253,6 +265,17 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } + // listenEP is nil when listenContext is used by tcp.Forwarder. + if l.listenEP != nil { + l.listenEP.mu.Lock() + if l.listenEP.state != StateListen { + l.listenEP.mu.Unlock() + return nil, tcpip.ErrConnectionAborted + } + l.addPendingEndpoint(ep) + l.listenEP.mu.Unlock() + } + // Perform the 3-way handshake. h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) @@ -260,6 +283,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if err := h.execute(); err != nil { ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() ep.Close() + if l.listenEP != nil { + l.removePendingEndpoint(ep) + } return nil, err } ep.mu.Lock() @@ -274,15 +300,41 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return ep, nil } +func (l *listenContext) addPendingEndpoint(n *endpoint) { + l.pendingMu.Lock() + l.pendingEndpoints[n.id] = n + l.pending.Add(1) + l.pendingMu.Unlock() +} + +func (l *listenContext) removePendingEndpoint(n *endpoint) { + l.pendingMu.Lock() + delete(l.pendingEndpoints, n.id) + l.pending.Done() + l.pendingMu.Unlock() +} + +func (l *listenContext) closeAllPendingEndpoints() { + l.pendingMu.Lock() + for _, n := range l.pendingEndpoints { + n.notifyProtocolGoroutine(notifyClose) + } + l.pendingMu.Unlock() + l.pending.Wait() +} + // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state, the new endpoint is closed // instead. func (e *endpoint) deliverAccepted(n *endpoint) { - e.mu.RLock() + e.mu.Lock() state := e.state - e.mu.RUnlock() + e.pendingAccepted.Add(1) + defer e.pendingAccepted.Done() + acceptedChan := e.acceptedChan + e.mu.Unlock() if state == StateListen { - e.acceptedChan <- n + acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) } else { n.Close() @@ -304,7 +356,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header e.stack.Stats().TCP.FailedConnectionAttempts.Increment() return } - + ctx.removePendingEndpoint(n) e.deliverAccepted(n) } @@ -451,6 +503,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // protocolListenLoop is the main loop of a listening TCP endpoint. It runs in // its own goroutine and is responsible for handling connection requests. func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { + e.mu.Lock() + v6only := e.v6only + e.mu.Unlock() + ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) + defer func() { // Mark endpoint as closed. This will prevent goroutines running // handleSynSegment() from attempting to queue new connections @@ -458,6 +515,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() e.state = StateClose + // close any endpoints in SYN-RCVD state. + ctx.closeAllPendingEndpoints() + // Do cleanup if needed. e.completeWorkerLocked() @@ -470,12 +530,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) }() - e.mu.Lock() - v6only := e.v6only - e.mu.Unlock() - - ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) - s := sleep.Sleeper{} s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) @@ -492,7 +546,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.handleListenSegment(ctx, s) s.decRef() } - synRcvdCount.pending.Wait() close(e.drainDone) <-e.undrain } |