diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 4 | ||||
-rw-r--r-- | test/syscalls/linux/socket_inet_loopback.cc | 111 |
3 files changed, 118 insertions, 22 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 3b574837c..0a2f3291c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -20,6 +20,7 @@ import ( "fmt" "hash" "io" + "sync/atomic" "time" "gvisor.dev/gvisor/pkg/rand" @@ -390,7 +391,7 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state (acceptedChan is nil), // the new endpoint is closed instead. -func (e *endpoint) deliverAccepted(n *endpoint) { +func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { e.mu.Lock() e.pendingAccepted.Add(1) e.mu.Unlock() @@ -405,6 +406,9 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } select { case e.acceptedChan <- n: + if !withSynCookie { + atomic.AddInt32(&e.synRcvdCount, -1) + } e.acceptMu.Unlock() e.waiterQueue.Notify(waiter.EventIn) return @@ -476,7 +480,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - e.synRcvdCount-- + atomic.AddInt32(&e.synRcvdCount, -1) return err } @@ -486,18 +490,13 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() ctx.cleanupFailedHandshake(h) - e.mu.Lock() - e.synRcvdCount-- - e.mu.Unlock() + atomic.AddInt32(&e.synRcvdCount, -1) return } ctx.cleanupCompletedHandshake(h) - e.mu.Lock() - e.synRcvdCount-- - e.mu.Unlock() h.ep.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(h.ep) + e.deliverAccepted(h.ep, false /*withSynCookie*/) }() // S/R-SAFE: synRcvdCount is the barrier. 
return nil @@ -505,17 +504,17 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header func (e *endpoint) incSynRcvdCount() bool { e.acceptMu.Lock() - canInc := e.synRcvdCount < cap(e.acceptedChan) + canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan) e.acceptMu.Unlock() if canInc { - e.synRcvdCount++ + atomic.AddInt32(&e.synRcvdCount, 1) } return canInc } func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() - full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan) + full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan) e.acceptMu.Unlock() return full } @@ -737,7 +736,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // Start the protocol goroutine. n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - go e.deliverAccepted(n) + go e.deliverAccepted(n, true /*withSynCookie*/) return nil default: diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 129f36d11..43d344350 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -532,8 +532,8 @@ type endpoint struct { segmentQueue segmentQueue `state:"wait"` // synRcvdCount is the number of connections for this endpoint that are - // in SYN-RCVD state. - synRcvdCount int + // in SYN-RCVD state; this is only accessed atomically. + synRcvdCount int32 // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. 
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 54b45b075..597b5bcb1 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -490,7 +490,11 @@ void TestListenWhileConnect(const TestParam& param, TestAddress const& connector = param.connector; constexpr int kBacklog = 2; - constexpr int kClients = kBacklog + 1; + // Linux completes one more connection than the listen backlog argument. + // To ensure that there is at least one client connection that stays in + // connecting state, keep 2 more client connections than the listen backlog. + // gVisor differs in this behavior though, gvisor.dev/issue/3153. + constexpr int kClients = kBacklog + 2; // Create the listening socket. FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( @@ -527,7 +531,7 @@ void TestListenWhileConnect(const TestParam& param, for (auto& client : clients) { constexpr int kTimeout = 10000; - struct pollfd pfd = { + pollfd pfd = { .fd = client.get(), .events = POLLIN, }; @@ -543,6 +547,10 @@ void TestListenWhileConnect(const TestParam& param, ASSERT_THAT(read(client.get(), &c, sizeof(c)), AnyOf(SyscallFailsWithErrno(ECONNRESET), SyscallFailsWithErrno(ECONNREFUSED))); + // The last client connection would be in connecting (SYN_SENT) state. + if (client.get() == clients[kClients - 1].get()) { + ASSERT_EQ(errno, ECONNREFUSED) << strerror(errno); + } } } @@ -598,7 +606,7 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) { connector.addr_len); if (ret != 0) { EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); - struct pollfd pfd = { + pollfd pfd = { .fd = conn_fd.get(), .events = POLLOUT, }; @@ -623,6 +631,95 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) { } } +// Test if the stack completes at most listen backlog number of client +// connections. It exercises the path of the stack that enqueues completed +// connections to accept queue vs new incoming SYNs. 
+TEST_P(SocketInetLoopbackTest, TCPConnectBacklog_NoRandomSave) { + const auto& param = GetParam(); + const TestAddress& listener = param.listener; + const TestAddress& connector = param.connector; + + constexpr int kBacklog = 1; + // Keep the number of client connections more than the listen backlog. + // Linux completes one more connection than the listen backlog argument. + // gVisor differs in this behavior though, gvisor.dev/issue/3153. + int kClients = kBacklog + 2; + if (IsRunningOnGvisor()) { + kClients--; + } + + // Run the following test for few iterations to test race between accept queue + // getting filled with incoming SYNs. + for (int num = 0; num < 10; num++) { + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + + std::vector<FileDescriptor> clients; + // Issue multiple non-blocking client connects. 
+ for (int i = 0; i < kClients; i++) { + FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + } + clients.push_back(std::move(client)); + } + + // Now that client connects are issued, wait for the accept queue to get + // filled and ensure no new client connection is completed. + for (int i = 0; i < kClients; i++) { + pollfd pfd = { + .fd = clients[i].get(), + .events = POLLOUT, + }; + if (i < kClients - 1) { + // Poll for client side connection completions with a large timeout. + // We cannot poll on the listener side without calling accept as poll + // stays level triggered with non-zero accept queue length. + // + // Client side poll would not guarantee that the completed connection + // has been enqueued into the accept queue, but the fact that the + // listener ACKd the SYN, means that it cannot complete any new incoming + // SYNs when it has already ACKd for > backlog number of SYNs. + ASSERT_THAT(poll(&pfd, 1, 10000), SyscallSucceedsWithValue(1)) + << "num=" << num << " i=" << i << " kClients=" << kClients; + ASSERT_EQ(pfd.revents, POLLOUT) << "num=" << num << " i=" << i; + } else { + // Now that we expect accept queue filled up, ensure that the last + // client connection never completes with a smaller poll timeout. + ASSERT_THAT(poll(&pfd, 1, 1000), SyscallSucceedsWithValue(0)) + << "num=" << num << " i=" << i; + } + + ASSERT_THAT(close(clients[i].release()), SyscallSucceedsWithValue(0)) + << "num=" << num << " i=" << i; + } + clients.clear(); + // We close the listening side and open a new listener. We could instead + // drain the accept queue by calling accept() and reuse the listener, but + // that is racy as the retransmitted SYNs could get ACKd as we make room in + // the accept queue. 
+ ASSERT_THAT(close(listen_fd.release()), SyscallSucceedsWithValue(0)); + } +} + // TCPFinWait2Test creates a pair of connected sockets then closes one end to // trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local // IP/port on a new socket and tries to connect. The connect should fail w/ @@ -937,7 +1034,7 @@ void setupTimeWaitClose(const TestAddress* listener, ASSERT_THAT(shutdown(active_closefd.get(), SHUT_WR), SyscallSucceeds()); { constexpr int kTimeout = 10000; - struct pollfd pfd = { + pollfd pfd = { .fd = passive_closefd.get(), .events = POLLIN, }; @@ -948,7 +1045,7 @@ void setupTimeWaitClose(const TestAddress* listener, { constexpr int kTimeout = 10000; constexpr int16_t want_events = POLLHUP; - struct pollfd pfd = { + pollfd pfd = { .fd = active_closefd.get(), .events = want_events, }; @@ -1181,7 +1278,7 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { // Wait for accept_fd to process the RST. constexpr int kTimeout = 10000; - struct pollfd pfd = { + pollfd pfd = { .fd = accept_fd.get(), .events = POLLIN, }; @@ -1705,7 +1802,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) { SyscallSucceedsWithValue(sizeof(i))); } - struct pollfd pollfds[kThreadCount]; + pollfd pollfds[kThreadCount]; for (int i = 0; i < kThreadCount; i++) { pollfds[i].fd = listener_fds[i].get(); pollfds[i].events = POLLIN; |