diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 89 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 49 |
2 files changed, 111 insertions, 27 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 52fd1bfa3..e9c5099ea 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -96,6 +96,17 @@ type listenContext struct { hasher hash.Hash v6only bool netProto tcpip.NetworkProtocolNumber + // pendingMu protects pendingEndpoints. This should only be accessed + // by the listening endpoint's worker goroutine. + // + // Lock Ordering: listenEP.workerMu -> pendingMu + pendingMu sync.Mutex + // pending is used to wait for all pendingEndpoints to finish when + // a socket is closed. + pending sync.WaitGroup + // pendingEndpoints is a map of all endpoints for which a handshake is + // in progress. + pendingEndpoints map[stack.TransportEndpointID]*endpoint } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. @@ -133,14 +144,15 @@ func decSynRcvdCount() { } // newListenContext creates a new listen context. -func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { +func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ - stack: stack, - rcvWnd: rcvWnd, - hasher: sha1.New(), - v6only: v6only, - netProto: netProto, - listenEP: listenEP, + stack: stk, + rcvWnd: rcvWnd, + hasher: sha1.New(), + v6only: v6only, + netProto: netProto, + listenEP: listenEP, + pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } rand.Read(l.nonce[0][:]) @@ -253,6 +265,17 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } + // listenEP is nil when listenContext is used by tcp.Forwarder. + if l.listenEP != nil { + l.listenEP.mu.Lock() + if l.listenEP.state != StateListen { + l.listenEP.mu.Unlock() + return nil, tcpip.ErrConnectionAborted + } + l.addPendingEndpoint(ep) + l.listenEP.mu.Unlock() + } + // Perform the 3-way handshake. h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) @@ -260,6 +283,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if err := h.execute(); err != nil { ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() ep.Close() + if l.listenEP != nil { + l.removePendingEndpoint(ep) + } return nil, err } ep.mu.Lock() @@ -274,15 +300,41 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return ep, nil } +func (l *listenContext) addPendingEndpoint(n *endpoint) { + l.pendingMu.Lock() + l.pendingEndpoints[n.id] = n + l.pending.Add(1) + l.pendingMu.Unlock() +} + +func (l *listenContext) removePendingEndpoint(n *endpoint) { + l.pendingMu.Lock() + delete(l.pendingEndpoints, n.id) + l.pending.Done() + l.pendingMu.Unlock() +} + +func (l *listenContext) closeAllPendingEndpoints() { + l.pendingMu.Lock() + for _, n := range l.pendingEndpoints { + n.notifyProtocolGoroutine(notifyClose) + } + l.pendingMu.Unlock() + l.pending.Wait() +} + // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state, the new endpoint is closed // instead. func (e *endpoint) deliverAccepted(n *endpoint) { - e.mu.RLock() + e.mu.Lock() state := e.state - e.mu.RUnlock() + e.pendingAccepted.Add(1) + defer e.pendingAccepted.Done() + acceptedChan := e.acceptedChan + e.mu.Unlock() if state == StateListen { - e.acceptedChan <- n + acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) } else { n.Close() @@ -304,7 +356,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header e.stack.Stats().TCP.FailedConnectionAttempts.Increment() return } - + ctx.removePendingEndpoint(n) e.deliverAccepted(n) } @@ -451,6 +503,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // protocolListenLoop is the main loop of a listening TCP endpoint. It runs in // its own goroutine and is responsible for handling connection requests. func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { + e.mu.Lock() + v6only := e.v6only + e.mu.Unlock() + ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) + defer func() { // Mark endpoint as closed. This will prevent goroutines running // handleSynSegment() from attempting to queue new connections @@ -458,6 +515,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() e.state = StateClose + // close any endpoints in SYN-RCVD state. + ctx.closeAllPendingEndpoints() + // Do cleanup if needed. e.completeWorkerLocked() @@ -470,12 +530,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) }() - e.mu.Lock() - v6only := e.v6only - e.mu.Unlock() - - ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) - s := sleep.Sleeper{} s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) @@ -492,7 +546,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.handleListenSegment(ctx, s) s.decRef() } - synRcvdCount.pending.Wait() close(e.drainDone) <-e.undrain } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 353e2efaf..0e16877e7 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -362,6 +362,12 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // pendingAccepted is a synchronization primitive used to track number + // of connections that are queued up to be delivered to the accepted + // channel. We use this to ensure that all goroutines blocked on writing + // to the acceptedChan below terminate before we close acceptedChan. + pendingAccepted sync.WaitGroup `state:"nosave"` + // acceptedChan is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. @@ -375,7 +381,11 @@ type endpoint struct { // The goroutine drain completion notification channel. drainDone chan struct{} `state:"nosave"` - // The goroutine undrain notification channel. + // The goroutine undrain notification channel. This is currently used as + // a way to block the worker goroutines. Today nothing closes/writes + // this channel and this causes any goroutines waiting on this to just + // block. This is used during save/restore to prevent worker goroutines + // from mutating state as it's being saved. undrain chan struct{} `state:"nosave"` // probe if not nil is invoked on every received segment. It is passed @@ -575,6 +585,34 @@ func (e *endpoint) Close() { e.mu.Unlock() } +// closePendingAcceptableConnections closes all connections that have completed +// handshake but not yet been delivered to the application. +func (e *endpoint) closePendingAcceptableConnectionsLocked() { + done := make(chan struct{}) + // Spin a goroutine up as ranging on e.acceptedChan will just block when + // there are no more connections in the channel. Using a non-blocking + // select does not work as it can potentially select the default case + // even when there are pending writes but that are not yet written to + // the channel. + go func() { + defer close(done) + for n := range e.acceptedChan { + n.mu.Lock() + n.resetConnectionLocked(tcpip.ErrConnectionAborted) + n.mu.Unlock() + n.Close() + } + }() + // pendingAccepted(see endpoint.deliverAccepted) tracks the number of + // endpoints which have completed handshake but are not yet written to + // the e.acceptedChan. We wait here till the goroutine above can drain + // all such connections from e.acceptedChan. + e.pendingAccepted.Wait() + close(e.acceptedChan) + <-done + e.acceptedChan = nil +} + // cleanupLocked frees all resources associated with the endpoint. It is called // after Close() is called and the worker goroutine (if any) is done with its // work. @@ -582,14 +620,7 @@ func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { - close(e.acceptedChan) - for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() - n.Close() - } - e.acceptedChan = nil + e.closePendingAcceptableConnectionsLocked() } e.workerCleanup = false |