From c633a7f9d1c15c8f1639a95809d875576ac7707f Mon Sep 17 00:00:00 2001 From: Arthur Sfez Date: Tue, 21 Sep 2021 16:09:58 -0700 Subject: Deliver endpoints to the accept queue synchronously when possible Before this change, when a new connection was created after receiving an ACK that matched a SYN-cookie, it was always delivered asynchronously to the accept queue. There was a chance that the listening endpoint would process a SYN from another client before the delivery happened, and the listening endpoint would not know yet that the queue was about to be full, once the delivery happened. Now, when an ACK matching a SYN-cookie is received, the new endpoint is created and moved to the accept queue synchronously, while holding the accept lock. Fixes #6545 PiperOrigin-RevId: 398107254 --- pkg/tcpip/transport/tcp/accept.go | 106 ++++++++++++++++-------------- pkg/tcpip/transport/tcp/endpoint.go | 8 ++- pkg/tcpip/transport/tcp/endpoint_state.go | 2 + test/syscalls/linux/tcp_socket.cc | 60 +++++++++++++++++ 4 files changed, 123 insertions(+), 53 deletions(-) diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 03c9fafa1..ff0a5df9c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -401,43 +401,6 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { e.h = nil } -// deliverAccepted delivers the newly-accepted endpoint to the listener. If the -// listener has transitioned out of the listen state (accepted is the zero -// value), the new endpoint is reset instead. -func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { - e.mu.Lock() - e.pendingAccepted.Add(1) - e.mu.Unlock() - defer e.pendingAccepted.Done() - - // Drop the lock before notifying to avoid deadlock in user-specified - // callbacks. - delivered := func() bool { - e.acceptMu.Lock() - defer e.acceptMu.Unlock() - for { - if e.accepted == (accepted{}) { - return false - } - if e.accepted.endpoints.Len() == e.accepted.cap { - e.acceptCond.Wait() - continue - } - - e.accepted.endpoints.PushBack(n) - if !withSynCookie { - atomic.AddInt32(&e.synRcvdCount, -1) - } - return true - } - }() - if delivered { - e.waiterQueue.Notify(waiter.ReadableEvents) - } else { - n.notifyProtocolGoroutine(notifyReset) - } -} - // propagateInheritableOptionsLocked propagates any options set on the listening // endpoint to the newly created endpoint. // @@ -521,7 +484,40 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header. ctx.cleanupCompletedHandshake(h) h.ep.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(h.ep, false /*withSynCookie*/) + + // Deliver the endpoint to the accept queue. + e.mu.Lock() + e.pendingAccepted.Add(1) + e.mu.Unlock() + defer e.pendingAccepted.Done() + + // Drop the lock before notifying to avoid deadlock in user-specified + // callbacks. + delivered := func() bool { + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + for { + if e.accepted == (accepted{}) { + // If the listener has transitioned out of the listen state (accepted + // is the zero value), the new endpoint is reset instead. + return false + } + if e.accepted.acceptQueueIsFullLocked() { + e.acceptCond.Wait() + continue + } + + e.accepted.endpoints.PushBack(h.ep) + atomic.AddInt32(&e.synRcvdCount, -1) + return true + } + }() + + if delivered { + e.waiterQueue.Notify(waiter.ReadableEvents) + } else { + h.ep.notifyProtocolGoroutine(notifyReset) + } }() return nil @@ -544,11 +540,15 @@ func (e *endpoint) synRcvdBacklogFull() bool { func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() - full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap + full := e.accepted.acceptQueueIsFullLocked() e.acceptMu.Unlock() return full } +func (a *accepted) acceptQueueIsFullLocked() bool { + return a.endpoints.Len() == a.cap +} + // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. // @@ -627,12 +627,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil case s.flags.Contains(header.TCPFlagAck): - if e.acceptQueueIsFull() { + // Keep hold of acceptMu until the new endpoint is in the accept queue (or + // if there is an error), to guarantee that we will keep our spot in the + // queue even if another handshake from the syn queue completes. + e.acceptMu.Lock() + if e.accepted.acceptQueueIsFullLocked() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be // retransmitted by the sender anyway and we can // complete the connection at the time of retransmit if // the backlog has space. + e.acceptMu.Unlock() e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() e.stack.Stats().DroppedPackets.Increment() @@ -654,6 +659,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // Validate the cookie. data, ok := ctx.isCookieValid(s.id, iss, irs) if !ok || int(data) >= len(mssTable) { + e.acceptMu.Unlock() e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() e.stack.Stats().DroppedPackets.Increment() @@ -695,6 +701,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{}) if err != nil { + e.acceptMu.Unlock() return err } @@ -706,6 +713,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err if !n.reserveTupleLocked() { n.mu.Unlock() + e.acceptMu.Unlock() n.Close() e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -723,6 +731,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n.boundBindToDevice, ); err != nil { n.mu.Unlock() + e.acceptMu.Unlock() n.Close() e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -755,20 +764,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n.newSegmentWaker.Assert() } - // Do the delivery in a separate goroutine so - // that we don't block the listen loop in case - // the application is slow to accept or stops - // accepting. - // - // NOTE: This won't result in an unbounded - // number of goroutines as we do check before - // entering here that there was at least some - // space available in the backlog. - // Start the protocol goroutine. n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - go e.deliverAccepted(n, true /*withSynCookie*/) + + // Deliver the endpoint to the accept queue. + e.accepted.endpoints.PushBack(n) + e.acceptMu.Unlock() + + e.waiterQueue.Notify(waiter.ReadableEvents) return nil default: diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d2b8f298f..83c51e855 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -315,7 +315,10 @@ type accepted struct { // belong to one list at a time, and endpoints are already stored in the // dispatcher's list. endpoints list.List `state:".([]*endpoint)"` - cap int + + // cap is the maximum number of endpoints that can be in the accepted endpoint + // list. + cap int } // endpoint represents a TCP endpoint. This struct serves as the interface @@ -333,7 +336,7 @@ type accepted struct { // The following three mutexes can be acquired independent of e.mu but if // acquired with e.mu then e.mu must be acquired first. // -// e.acceptMu -> protects accepted. +// e.acceptMu -> Protects e.accepted. // e.rcvQueueMu -> Protects e.rcvQueue and associated fields. // e.sndQueueMu -> Protects the e.sndQueue and associated fields. // e.lastErrorMu -> Protects the lastError field. @@ -573,6 +576,7 @@ type endpoint struct { // accepted is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. + // +checklocks:acceptMu accepted accepted // The following are only used from the protocol goroutine, and diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index f2e8b3840..381f4474d 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -251,7 +251,9 @@ func (e *endpoint) Resume(s *stack.Stack) { go func() { connectedLoading.Wait() bind() + e.acceptMu.Lock() backlog := e.accepted.cap + e.acceptMu.Unlock() if err := e.Listen(backlog); err != nil { panic("endpoint listening failed: " + err.String()) } diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 3fbbf1423..607182ffd 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -2088,6 +2088,66 @@ TEST_P(SimpleTcpSocketTest, ConnectUnspecifiedAddress) { } } +TEST_P(SimpleTcpSocketTest, OnlyAcknowledgeBacklogConnections) { + // At some point, there was a bug in gVisor where a connection could be + // SYN-ACK'd by the server even if the accept queue was already full. This was + // possible because once the listener would process an ACK, it would move the + // new connection in the accept queue asynchronously. It created an + // opportunity where the listener could process another SYN before completing + // the delivery that would have filled the accept queue. + // + // This test checks that there is no such race. + + std::array, 100> threads; + for (auto& thread : threads) { + thread.emplace([]() { + FileDescriptor bound_s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage bound_addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t bound_addrlen = sizeof(bound_addr); + + ASSERT_THAT(bind(bound_s.get(), AsSockAddr(&bound_addr), bound_addrlen), + SyscallSucceeds()); + + // Start listening. Use a zero backlog to only allow one connection in the + // accept queue. + ASSERT_THAT(listen(bound_s.get(), 0), SyscallSucceeds()); + + // Get the addresses the socket is bound to because the port is chosen by + // the stack. + ASSERT_THAT( + getsockname(bound_s.get(), AsSockAddr(&bound_addr), &bound_addrlen), + SyscallSucceeds()); + + // Establish a connection, but do not accept it. + FileDescriptor connected_s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(connected_s.get(), + reinterpret_cast(&bound_addr), + bound_addrlen), + SyscallSucceeds()); + + // Immediately attempt to establish another connection. Use non blocking + // socket because this is expected to timeout. + FileDescriptor connecting_s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + ASSERT_THAT(connect(connecting_s.get(), + reinterpret_cast(&bound_addr), + bound_addrlen), + SyscallFailsWithErrno(EINPROGRESS)); + + struct pollfd poll_fd = { + .fd = connecting_s.get(), + .events = POLLOUT, + }; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), + SyscallSucceedsWithValue(0)); + }); + } +} + // Tests that send will return EWOULDBLOCK initially with large buffer and will // succeed after the send buffer size is increased. TEST_P(TcpSocketTest, SendUnblocksOnSendBufferIncrease) { -- cgit v1.2.3