summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/tcp/accept.go106
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go8
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go2
-rw-r--r--test/syscalls/linux/tcp_socket.cc60
4 files changed, 123 insertions, 53 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 03c9fafa1..ff0a5df9c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -401,43 +401,6 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e.h = nil
}
-// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// listener has transitioned out of the listen state (accepted is the zero
-// value), the new endpoint is reset instead.
-func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
- e.mu.Lock()
- e.pendingAccepted.Add(1)
- e.mu.Unlock()
- defer e.pendingAccepted.Done()
-
- // Drop the lock before notifying to avoid deadlock in user-specified
- // callbacks.
- delivered := func() bool {
- e.acceptMu.Lock()
- defer e.acceptMu.Unlock()
- for {
- if e.accepted == (accepted{}) {
- return false
- }
- if e.accepted.endpoints.Len() == e.accepted.cap {
- e.acceptCond.Wait()
- continue
- }
-
- e.accepted.endpoints.PushBack(n)
- if !withSynCookie {
- atomic.AddInt32(&e.synRcvdCount, -1)
- }
- return true
- }
- }()
- if delivered {
- e.waiterQueue.Notify(waiter.ReadableEvents)
- } else {
- n.notifyProtocolGoroutine(notifyReset)
- }
-}
-
// propagateInheritableOptionsLocked propagates any options set on the listening
// endpoint to the newly created endpoint.
//
@@ -521,7 +484,40 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.
ctx.cleanupCompletedHandshake(h)
h.ep.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(h.ep, false /*withSynCookie*/)
+
+ // Deliver the endpoint to the accept queue.
+ e.mu.Lock()
+ e.pendingAccepted.Add(1)
+ e.mu.Unlock()
+ defer e.pendingAccepted.Done()
+
+ // Drop the lock before notifying to avoid deadlock in user-specified
+ // callbacks.
+ delivered := func() bool {
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ for {
+ if e.accepted == (accepted{}) {
+ // If the listener has transitioned out of the listen state (accepted
+ // is the zero value), the new endpoint is reset instead.
+ return false
+ }
+ if e.accepted.acceptQueueIsFullLocked() {
+ e.acceptCond.Wait()
+ continue
+ }
+
+ e.accepted.endpoints.PushBack(h.ep)
+ atomic.AddInt32(&e.synRcvdCount, -1)
+ return true
+ }
+ }()
+
+ if delivered {
+ e.waiterQueue.Notify(waiter.ReadableEvents)
+ } else {
+ h.ep.notifyProtocolGoroutine(notifyReset)
+ }
}()
return nil
@@ -544,11 +540,15 @@ func (e *endpoint) synRcvdBacklogFull() bool {
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
- full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap
+ full := e.accepted.acceptQueueIsFullLocked()
e.acceptMu.Unlock()
return full
}
+func (a *accepted) acceptQueueIsFullLocked() bool {
+ return a.endpoints.Len() == a.cap
+}
+
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
//
@@ -627,12 +627,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
case s.flags.Contains(header.TCPFlagAck):
- if e.acceptQueueIsFull() {
+ // Keep hold of acceptMu until the new endpoint is in the accept queue (or
+ // if there is an error), to guarantee that we will keep our spot in the
+ // queue even if another handshake from the syn queue completes.
+ e.acceptMu.Lock()
+ if e.accepted.acceptQueueIsFullLocked() {
// Silently drop the ack as the application can't accept
// the connection at this point. The ack will be
// retransmitted by the sender anyway and we can
// complete the connection at the time of retransmit if
// the backlog has space.
+ e.acceptMu.Unlock()
e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -654,6 +659,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// Validate the cookie.
data, ok := ctx.isCookieValid(s.id, iss, irs)
if !ok || int(data) >= len(mssTable) {
+ e.acceptMu.Unlock()
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -695,6 +701,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{})
if err != nil {
+ e.acceptMu.Unlock()
return err
}
@@ -706,6 +713,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
if !n.reserveTupleLocked() {
n.mu.Unlock()
+ e.acceptMu.Unlock()
n.Close()
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -723,6 +731,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n.boundBindToDevice,
); err != nil {
n.mu.Unlock()
+ e.acceptMu.Unlock()
n.Close()
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -755,20 +764,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n.newSegmentWaker.Assert()
}
- // Do the delivery in a separate goroutine so
- // that we don't block the listen loop in case
- // the application is slow to accept or stops
- // accepting.
- //
- // NOTE: This won't result in an unbounded
- // number of goroutines as we do check before
- // entering here that there was at least some
- // space available in the backlog.
-
// Start the protocol goroutine.
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- go e.deliverAccepted(n, true /*withSynCookie*/)
+
+ // Deliver the endpoint to the accept queue.
+ e.accepted.endpoints.PushBack(n)
+ e.acceptMu.Unlock()
+
+ e.waiterQueue.Notify(waiter.ReadableEvents)
return nil
default:
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index d2b8f298f..83c51e855 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -315,7 +315,10 @@ type accepted struct {
// belong to one list at a time, and endpoints are already stored in the
// dispatcher's list.
endpoints list.List `state:".([]*endpoint)"`
- cap int
+
+ // cap is the maximum number of endpoints that can be in the accepted endpoint
+ // list.
+ cap int
}
// endpoint represents a TCP endpoint. This struct serves as the interface
@@ -333,7 +336,7 @@ type accepted struct {
// The following three mutexes can be acquired independent of e.mu but if
// acquired with e.mu then e.mu must be acquired first.
//
-// e.acceptMu -> protects accepted.
+// e.acceptMu -> Protects e.accepted.
// e.rcvQueueMu -> Protects e.rcvQueue and associated fields.
// e.sndQueueMu -> Protects the e.sndQueue and associated fields.
// e.lastErrorMu -> Protects the lastError field.
@@ -573,6 +576,7 @@ type endpoint struct {
// accepted is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
+ // +checklocks:acceptMu
accepted accepted
// The following are only used from the protocol goroutine, and
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index f2e8b3840..381f4474d 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -251,7 +251,9 @@ func (e *endpoint) Resume(s *stack.Stack) {
go func() {
connectedLoading.Wait()
bind()
+ e.acceptMu.Lock()
backlog := e.accepted.cap
+ e.acceptMu.Unlock()
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 3fbbf1423..607182ffd 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -2088,6 +2088,66 @@ TEST_P(SimpleTcpSocketTest, ConnectUnspecifiedAddress) {
}
}
+TEST_P(SimpleTcpSocketTest, OnlyAcknowledgeBacklogConnections) {
+ // At some point, there was a bug in gVisor where a connection could be
+ // SYN-ACK'd by the server even if the accept queue was already full. This was
+ // possible because once the listener would process an ACK, it would move the
+ // new connection in the accept queue asynchronously. It created an
+ // opportunity where the listener could process another SYN before completing
+ // the delivery that would have filled the accept queue.
+ //
+ // This test checks that there is no such race.
+
+ std::array<std::optional<ScopedThread>, 100> threads;
+ for (auto& thread : threads) {
+ thread.emplace([]() {
+ FileDescriptor bound_s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage bound_addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t bound_addrlen = sizeof(bound_addr);
+
+ ASSERT_THAT(bind(bound_s.get(), AsSockAddr(&bound_addr), bound_addrlen),
+ SyscallSucceeds());
+
+ // Start listening. Use a zero backlog to only allow one connection in the
+ // accept queue.
+ ASSERT_THAT(listen(bound_s.get(), 0), SyscallSucceeds());
+
+ // Get the addresses the socket is bound to because the port is chosen by
+ // the stack.
+ ASSERT_THAT(
+ getsockname(bound_s.get(), AsSockAddr(&bound_addr), &bound_addrlen),
+ SyscallSucceeds());
+
+ // Establish a connection, but do not accept it.
+ FileDescriptor connected_s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(connect(connected_s.get(),
+ reinterpret_cast<const struct sockaddr*>(&bound_addr),
+ bound_addrlen),
+ SyscallSucceeds());
+
+ // Immediately attempt to establish another connection. Use non blocking
+ // socket because this is expected to timeout.
+ FileDescriptor connecting_s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ ASSERT_THAT(connect(connecting_s.get(),
+ reinterpret_cast<const struct sockaddr*>(&bound_addr),
+ bound_addrlen),
+ SyscallFailsWithErrno(EINPROGRESS));
+
+ struct pollfd poll_fd = {
+ .fd = connecting_s.get(),
+ .events = POLLOUT,
+ };
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10),
+ SyscallSucceedsWithValue(0));
+ });
+ }
+}
+
// Tests that send will return EWOULDBLOCK initially with large buffer and will
// succeed after the send buffer size is increased.
TEST_P(TcpSocketTest, SendUnblocksOnSendBufferIncrease) {