From dfbc0b0a4cabc6468c82a7665ff655fd4a633dd9 Mon Sep 17 00:00:00 2001
From: Bhasker Hariharan
Date: Tue, 6 Aug 2019 10:59:49 -0700
Subject: Fix for a panic due to writing to a closed accept channel.

The panic can happen because endpoint.Close() closes the accept channel
first and only then drains/resets any accepted-but-undelivered
connections. A connection can complete its handshake while the accept
channel is full, leaving a goroutine blocked trying to write the new
endpoint to the channel; closing the channel underneath that goroutine
turns the blocked write into a send on a closed channel, which panics.

The correct solution is to abort any connections still in SYN-RCVD
state and to drain/abort all completed connections before closing the
accept channel. (Two standalone sketches of the failure mode and of the
wait-then-close pattern follow the patch below.)

PiperOrigin-RevId: 261951132
---
 pkg/tcpip/transport/tcp/accept.go          | 89 ++++++++++++++++++++++++------
 pkg/tcpip/transport/tcp/dual_stack_test.go | 86 +++++++++++++++++++++++++++++
 pkg/tcpip/transport/tcp/endpoint.go        | 49 +++++++++++++---
 3 files changed, 197 insertions(+), 27 deletions(-)

(limited to 'pkg/tcpip/transport')

diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 52fd1bfa3..e9c5099ea 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -96,6 +96,17 @@ type listenContext struct {
 	hasher   hash.Hash
 	v6only   bool
 	netProto tcpip.NetworkProtocolNumber
+	// pendingMu protects pendingEndpoints. This should only be accessed
+	// by the listening endpoint's worker goroutine.
+	//
+	// Lock Ordering: listenEP.workerMu -> pendingMu
+	pendingMu sync.Mutex
+	// pending is used to wait for all pendingEndpoints to finish when
+	// a socket is closed.
+	pending sync.WaitGroup
+	// pendingEndpoints is a map of all endpoints for which a handshake is
+	// in progress.
+	pendingEndpoints map[stack.TransportEndpointID]*endpoint
 }
 
 // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
@@ -133,14 +144,15 @@ func decSynRcvdCount() {
 }
 
 // newListenContext creates a new listen context.
-func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
 	l := &listenContext{
-		stack:    stack,
-		rcvWnd:   rcvWnd,
-		hasher:   sha1.New(),
-		v6only:   v6only,
-		netProto: netProto,
-		listenEP: listenEP,
+		stack:            stk,
+		rcvWnd:           rcvWnd,
+		hasher:           sha1.New(),
+		v6only:           v6only,
+		netProto:         netProto,
+		listenEP:         listenEP,
+		pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
 	}
 
 	rand.Read(l.nonce[0][:])
@@ -253,6 +265,17 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
 		return nil, err
 	}
 
+	// listenEP is nil when listenContext is used by tcp.Forwarder.
+	if l.listenEP != nil {
+		l.listenEP.mu.Lock()
+		if l.listenEP.state != StateListen {
+			l.listenEP.mu.Unlock()
+			return nil, tcpip.ErrConnectionAborted
+		}
+		l.addPendingEndpoint(ep)
+		l.listenEP.mu.Unlock()
+	}
+
 	// Perform the 3-way handshake.
 	h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
@@ -260,6 +283,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
 	if err := h.execute(); err != nil {
 		ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
 		ep.Close()
+		if l.listenEP != nil {
+			l.removePendingEndpoint(ep)
+		}
 		return nil, err
 	}
 	ep.mu.Lock()
@@ -274,15 +300,41 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
 	return ep, nil
 }
 
+func (l *listenContext) addPendingEndpoint(n *endpoint) {
+	l.pendingMu.Lock()
+	l.pendingEndpoints[n.id] = n
+	l.pending.Add(1)
+	l.pendingMu.Unlock()
+}
+
+func (l *listenContext) removePendingEndpoint(n *endpoint) {
+	l.pendingMu.Lock()
+	delete(l.pendingEndpoints, n.id)
+	l.pending.Done()
+	l.pendingMu.Unlock()
+}
+
+func (l *listenContext) closeAllPendingEndpoints() {
+	l.pendingMu.Lock()
+	for _, n := range l.pendingEndpoints {
+		n.notifyProtocolGoroutine(notifyClose)
+	}
+	l.pendingMu.Unlock()
+	l.pending.Wait()
+}
+
 // deliverAccepted delivers the newly-accepted endpoint to the listener. If the
 // endpoint has transitioned out of the listen state, the new endpoint is closed
 // instead.
 func (e *endpoint) deliverAccepted(n *endpoint) {
-	e.mu.RLock()
+	e.mu.Lock()
 	state := e.state
-	e.mu.RUnlock()
+	e.pendingAccepted.Add(1)
+	defer e.pendingAccepted.Done()
+	acceptedChan := e.acceptedChan
+	e.mu.Unlock()
 	if state == StateListen {
-		e.acceptedChan <- n
+		acceptedChan <- n
 		e.waiterQueue.Notify(waiter.EventIn)
 	} else {
 		n.Close()
@@ -304,7 +356,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
 		e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
 		return
 	}
-
+	ctx.removePendingEndpoint(n)
 	e.deliverAccepted(n)
 }
 
@@ -451,6 +503,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
 // protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
 // its own goroutine and is responsible for handling connection requests.
 func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+	e.mu.Lock()
+	v6only := e.v6only
+	e.mu.Unlock()
+	ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
+
 	defer func() {
 		// Mark endpoint as closed. This will prevent goroutines running
 		// handleSynSegment() from attempting to queue new connections
@@ -458,6 +515,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
 		e.mu.Lock()
 		e.state = StateClose
 
+		// Close any endpoints in SYN-RCVD state.
+		ctx.closeAllPendingEndpoints()
+
 		// Do cleanup if needed.
 		e.completeWorkerLocked()
 
@@ -470,12 +530,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
 		e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
 	}()
 
-	e.mu.Lock()
-	v6only := e.v6only
-	e.mu.Unlock()
-
-	ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
-
 	s := sleep.Sleeper{}
 	s.AddWaker(&e.notificationWaker, wakerForNotification)
 	s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
@@ -492,7 +546,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
 				e.handleListenSegment(ctx, s)
 				s.decRef()
 			}
-			synRcvdCount.pending.Wait()
 			close(e.drainDone)
 			<-e.undrain
 		}
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index d9f79e8c5..c54610a87 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -570,3 +570,89 @@ func TestV4AcceptOnV4(t *testing.T) {
 	// Test acceptance.
 	testV4Accept(t, c)
 }
+
+func testV4ListenClose(t *testing.T, c *context.Context) {
+	// Set the SynRcvd threshold to zero to force a SYN-cookie based
+	// accept to happen.
+	saved := tcp.SynRcvdCountThreshold
+	defer func() {
+		tcp.SynRcvdCountThreshold = saved
+	}()
+	tcp.SynRcvdCountThreshold = 0
+	const n = uint16(32)
+
+	// Start listening.
+	if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil {
+		t.Fatalf("Listen failed: %v", err)
+	}
+
+	irs := seqnum.Value(789)
+	for i := uint16(0); i < n; i++ {
+		// Send a SYN request.
+		c.SendPacket(nil, &context.Headers{
+			SrcPort: context.TestPort + i,
+			DstPort: context.StackPort,
+			Flags:   header.TCPFlagSyn,
+			SeqNum:  irs,
+			RcvWnd:  30000,
+		})
+	}
+
+	// Each of these ACKs will cause a SYN-cookie based connection to be
+	// accepted and delivered to the listening endpoint.
+	for i := uint16(0); i < n; i++ {
+		b := c.GetPacket()
+		tcp := header.TCP(header.IPv4(b).Payload())
+		iss := seqnum.Value(tcp.SequenceNumber())
+		// Send ACK.
+		c.SendPacket(nil, &context.Headers{
+			SrcPort: tcp.DestinationPort(),
+			DstPort: context.StackPort,
+			Flags:   header.TCPFlagAck,
+			SeqNum:  irs + 1,
+			AckNum:  iss + 1,
+			RcvWnd:  30000,
+		})
+	}
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventIn)
+	defer c.WQ.EventUnregister(&we)
+	nep, _, err := c.EP.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			nep, _, err = c.EP.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %v", err)
+			}
+
+		case <-time.After(10 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+	nep.Close()
+	c.EP.Close()
+}
+
+func TestV4ListenCloseOnV4(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Create TCP endpoint.
+	var err *tcpip.Error
+	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %v", err)
+	}
+
+	// Bind to wildcard.
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test the close-while-listening sequence.
+	testV4ListenClose(t, c)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 353e2efaf..0e16877e7 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -362,6 +362,12 @@ type endpoint struct {
 	// without hearing a response, the connection is closed.
 	keepalive keepalive
 
+	// pendingAccepted is a synchronization primitive used to track the
+	// number of connections that are queued up to be delivered to the
+	// accepted channel. We use this to ensure that all goroutines blocked
+	// on writing to acceptedChan below terminate before we close it.
+	pendingAccepted sync.WaitGroup `state:"nosave"`
+
 	// acceptedChan is used by a listening endpoint protocol goroutine to
 	// send newly accepted connections to the endpoint so that they can be
 	// read by Accept() calls.
@@ -375,7 +381,11 @@ type endpoint struct {
 	// The goroutine drain completion notification channel.
 	drainDone chan struct{} `state:"nosave"`
 
-	// The goroutine undrain notification channel.
+	// The goroutine undrain notification channel. This is currently used
+	// as a way to block the worker goroutines: today nothing closes or
+	// writes to this channel, so any goroutine waiting on it simply
+	// blocks. This is used during save/restore to prevent worker
+	// goroutines from mutating state while it is being saved.
 	undrain chan struct{} `state:"nosave"`
 
 	// probe if not nil is invoked on every received segment. It is passed
@@ -575,6 +585,34 @@ func (e *endpoint) Close() {
 	e.mu.Unlock()
 }
 
+// closePendingAcceptableConnectionsLocked closes all connections that have
+// completed the handshake but have not yet been delivered to the application.
+func (e *endpoint) closePendingAcceptableConnectionsLocked() {
+	done := make(chan struct{})
+	// Spin up a goroutine, because ranging over e.acceptedChan would
+	// simply block once the channel is empty. A non-blocking select does
+	// not work either: it can choose the default case even while there
+	// are pending writes that have been registered but not yet performed
+	// on the channel.
+	go func() {
+		defer close(done)
+		for n := range e.acceptedChan {
+			n.mu.Lock()
+			n.resetConnectionLocked(tcpip.ErrConnectionAborted)
+			n.mu.Unlock()
+			n.Close()
+		}
+	}()
+	// pendingAccepted (see endpoint.deliverAccepted) tracks the number of
+	// endpoints that have completed the handshake but have not yet been
+	// written to e.acceptedChan. Wait here until the goroutine above has
+	// drained all such connections from e.acceptedChan.
+	e.pendingAccepted.Wait()
+	close(e.acceptedChan)
+	<-done
+	e.acceptedChan = nil
+}
+
 // cleanupLocked frees all resources associated with the endpoint. It is called
 // after Close() is called and the worker goroutine (if any) is done with its
 // work.
@@ -582,14 +620,7 @@ func (e *endpoint) cleanupLocked() {
 	// Close all endpoints that might have been accepted by TCP but not by
 	// the client.
 	if e.acceptedChan != nil {
-		close(e.acceptedChan)
-		for n := range e.acceptedChan {
-			n.mu.Lock()
-			n.resetConnectionLocked(tcpip.ErrConnectionAborted)
-			n.mu.Unlock()
-			n.Close()
-		}
-		e.acceptedChan = nil
+		e.closePendingAcceptableConnectionsLocked()
 	}
 
 	e.workerCleanup = false
--
cgit v1.2.3
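
The following is a minimal, self-contained sketch (illustrative code, not
gVisor's) of the failure mode the commit message describes: a goroutine
blocked sending on a full channel panics with "send on closed channel" the
moment the channel is closed underneath it, which is exactly what the old
close-first ordering in endpoint.Close() allowed.

package main

import "time"

func main() {
	// acceptedChan stands in for the listening endpoint's accept
	// channel. Its buffer is full, so the next send will block.
	acceptedChan := make(chan int, 1)
	acceptedChan <- 1

	go func() {
		// Blocks: the buffer is full and nothing is reading.
		acceptedChan <- 2
	}()

	// Give the sender time to block, then close the channel first, the
	// way the old endpoint.Close() did. The blocked send now panics
	// with "send on closed channel".
	time.Sleep(100 * time.Millisecond)
	close(acceptedChan)
	time.Sleep(100 * time.Millisecond)
}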
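
And a sketch of the fix pattern the patch adopts, again with illustrative
names: listener, deliver, and shutdown are stand-ins for the endpoint,
deliverAccepted, and closePendingAcceptableConnectionsLocked, simplified to
plain ints rather than real endpoints. A WaitGroup counts in-flight senders,
a helper goroutine drains and aborts undelivered connections, and the channel
is closed only once every sender has returned.

package main

import (
	"fmt"
	"sync"
)

type listener struct {
	mu       sync.Mutex
	open     bool           // stands in for state == StateListen
	accepted chan int       // stands in for acceptedChan
	pending  sync.WaitGroup // stands in for pendingAccepted
}

// deliver mirrors deliverAccepted: register as an in-flight sender while
// holding the lock, then perform the (possibly blocking) send outside it.
func (l *listener) deliver(conn int) {
	l.mu.Lock()
	open := l.open
	l.pending.Add(1)
	defer l.pending.Done()
	l.mu.Unlock()

	if open {
		l.accepted <- conn // may block; shutdown() now waits for us
	}
}

// shutdown mirrors closePendingAcceptableConnectionsLocked: drain from a
// helper goroutine, wait for all senders to finish, then close the channel.
func (l *listener) shutdown() {
	l.mu.Lock()
	l.open = false
	l.mu.Unlock()

	done := make(chan struct{})
	go func() {
		defer close(done)
		for conn := range l.accepted {
			fmt.Println("aborting undelivered conn", conn)
		}
	}()

	l.pending.Wait()  // every deliver() has returned...
	close(l.accepted) // ...so no sender can be blocked here
	<-done
}

func main() {
	l := &listener{open: true, accepted: make(chan int, 1)}
	l.accepted <- 1 // backlog already full
	go l.deliver(2) // would panic under the old close-first ordering
	l.shutdown()
}

The drain goroutine exists for the reason the patch's comment gives: a
non-blocking select over the channel could hit its default case while a
sender is still between registering with the WaitGroup and performing its
send, so the only safe order is drain, wait for senders, then close.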