diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 106 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 2 | ||||
-rw-r--r-- | test/syscalls/linux/tcp_socket.cc | 60 |
4 files changed, 123 insertions, 53 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 03c9fafa1..ff0a5df9c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -401,43 +401,6 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { e.h = nil } -// deliverAccepted delivers the newly-accepted endpoint to the listener. If the -// listener has transitioned out of the listen state (accepted is the zero -// value), the new endpoint is reset instead. -func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { - e.mu.Lock() - e.pendingAccepted.Add(1) - e.mu.Unlock() - defer e.pendingAccepted.Done() - - // Drop the lock before notifying to avoid deadlock in user-specified - // callbacks. - delivered := func() bool { - e.acceptMu.Lock() - defer e.acceptMu.Unlock() - for { - if e.accepted == (accepted{}) { - return false - } - if e.accepted.endpoints.Len() == e.accepted.cap { - e.acceptCond.Wait() - continue - } - - e.accepted.endpoints.PushBack(n) - if !withSynCookie { - atomic.AddInt32(&e.synRcvdCount, -1) - } - return true - } - }() - if delivered { - e.waiterQueue.Notify(waiter.ReadableEvents) - } else { - n.notifyProtocolGoroutine(notifyReset) - } -} - // propagateInheritableOptionsLocked propagates any options set on the listening // endpoint to the newly created endpoint. // @@ -521,7 +484,40 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header. ctx.cleanupCompletedHandshake(h) h.ep.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(h.ep, false /*withSynCookie*/) + + // Deliver the endpoint to the accept queue. + e.mu.Lock() + e.pendingAccepted.Add(1) + e.mu.Unlock() + defer e.pendingAccepted.Done() + + // Drop the lock before notifying to avoid deadlock in user-specified + // callbacks. + delivered := func() bool { + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + for { + if e.accepted == (accepted{}) { + // If the listener has transitioned out of the listen state (accepted + // is the zero value), the new endpoint is reset instead. + return false + } + if e.accepted.acceptQueueIsFullLocked() { + e.acceptCond.Wait() + continue + } + + e.accepted.endpoints.PushBack(h.ep) + atomic.AddInt32(&e.synRcvdCount, -1) + return true + } + }() + + if delivered { + e.waiterQueue.Notify(waiter.ReadableEvents) + } else { + h.ep.notifyProtocolGoroutine(notifyReset) + } }() return nil @@ -544,11 +540,15 @@ func (e *endpoint) synRcvdBacklogFull() bool { func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() - full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap + full := e.accepted.acceptQueueIsFullLocked() e.acceptMu.Unlock() return full } +func (a *accepted) acceptQueueIsFullLocked() bool { + return a.endpoints.Len() == a.cap +} + // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. // @@ -627,12 +627,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil case s.flags.Contains(header.TCPFlagAck): - if e.acceptQueueIsFull() { + // Keep hold of acceptMu until the new endpoint is in the accept queue (or + // if there is an error), to guarantee that we will keep our spot in the + // queue even if another handshake from the syn queue completes. + e.acceptMu.Lock() + if e.accepted.acceptQueueIsFullLocked() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be // retransmitted by the sender anyway and we can // complete the connection at the time of retransmit if // the backlog has space. + e.acceptMu.Unlock() e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() e.stack.Stats().DroppedPackets.Increment() @@ -654,6 +659,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // Validate the cookie. data, ok := ctx.isCookieValid(s.id, iss, irs) if !ok || int(data) >= len(mssTable) { + e.acceptMu.Unlock() e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() e.stack.Stats().DroppedPackets.Increment() @@ -695,6 +701,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{}) if err != nil { + e.acceptMu.Unlock() return err } @@ -706,6 +713,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err if !n.reserveTupleLocked() { n.mu.Unlock() + e.acceptMu.Unlock() n.Close() e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -723,6 +731,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n.boundBindToDevice, ); err != nil { n.mu.Unlock() + e.acceptMu.Unlock() n.Close() e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -755,20 +764,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err n.newSegmentWaker.Assert() } - // Do the delivery in a separate goroutine so - // that we don't block the listen loop in case - // the application is slow to accept or stops - // accepting. - // - // NOTE: This won't result in an unbounded - // number of goroutines as we do check before - // entering here that there was at least some - // space available in the backlog. - // Start the protocol goroutine. n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - go e.deliverAccepted(n, true /*withSynCookie*/) + + // Deliver the endpoint to the accept queue. + e.accepted.endpoints.PushBack(n) + e.acceptMu.Unlock() + + e.waiterQueue.Notify(waiter.ReadableEvents) return nil default: diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d2b8f298f..83c51e855 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -315,7 +315,10 @@ type accepted struct { // belong to one list at a time, and endpoints are already stored in the // dispatcher's list. endpoints list.List `state:".([]*endpoint)"` - cap int + + // cap is the maximum number of endpoints that can be in the accepted endpoint + // list. + cap int } // endpoint represents a TCP endpoint. This struct serves as the interface @@ -333,7 +336,7 @@ type accepted struct { // The following three mutexes can be acquired independent of e.mu but if // acquired with e.mu then e.mu must be acquired first. // -// e.acceptMu -> protects accepted. +// e.acceptMu -> Protects e.accepted. // e.rcvQueueMu -> Protects e.rcvQueue and associated fields. // e.sndQueueMu -> Protects the e.sndQueue and associated fields. // e.lastErrorMu -> Protects the lastError field. @@ -573,6 +576,7 @@ type endpoint struct { // accepted is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. + // +checklocks:acceptMu accepted accepted // The following are only used from the protocol goroutine, and diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index f2e8b3840..381f4474d 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -251,7 +251,9 @@ func (e *endpoint) Resume(s *stack.Stack) { go func() { connectedLoading.Wait() bind() + e.acceptMu.Lock() backlog := e.accepted.cap + e.acceptMu.Unlock() if err := e.Listen(backlog); err != nil { panic("endpoint listening failed: " + err.String()) } diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 3fbbf1423..607182ffd 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -2088,6 +2088,66 @@ TEST_P(SimpleTcpSocketTest, ConnectUnspecifiedAddress) { } } +TEST_P(SimpleTcpSocketTest, OnlyAcknowledgeBacklogConnections) { + // At some point, there was a bug in gVisor where a connection could be + // SYN-ACK'd by the server even if the accept queue was already full. This was + // possible because once the listener would process an ACK, it would move the + // new connection in the accept queue asynchronously. It created an + // opportunity where the listener could process another SYN before completing + // the delivery that would have filled the accept queue. + // + // This test checks that there is no such race. + + std::array<std::optional<ScopedThread>, 100> threads; + for (auto& thread : threads) { + thread.emplace([]() { + FileDescriptor bound_s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage bound_addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t bound_addrlen = sizeof(bound_addr); + + ASSERT_THAT(bind(bound_s.get(), AsSockAddr(&bound_addr), bound_addrlen), + SyscallSucceeds()); + + // Start listening. Use a zero backlog to only allow one connection in the + // accept queue. + ASSERT_THAT(listen(bound_s.get(), 0), SyscallSucceeds()); + + // Get the addresses the socket is bound to because the port is chosen by + // the stack. + ASSERT_THAT( + getsockname(bound_s.get(), AsSockAddr(&bound_addr), &bound_addrlen), + SyscallSucceeds()); + + // Establish a connection, but do not accept it. + FileDescriptor connected_s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(connected_s.get(), + reinterpret_cast<const struct sockaddr*>(&bound_addr), + bound_addrlen), + SyscallSucceeds()); + + // Immediately attempt to establish another connection. Use non blocking + // socket because this is expected to timeout. + FileDescriptor connecting_s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + ASSERT_THAT(connect(connecting_s.get(), + reinterpret_cast<const struct sockaddr*>(&bound_addr), + bound_addrlen), + SyscallFailsWithErrno(EINPROGRESS)); + + struct pollfd poll_fd = { + .fd = connecting_s.get(), + .events = POLLOUT, + }; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), + SyscallSucceedsWithValue(0)); + }); + } +} + // Tests that send will return EWOULDBLOCK initially with large buffer and will // succeed after the send buffer size is increased. TEST_P(TcpSocketTest, SendUnblocksOnSendBufferIncrease) { |