-rw-r--r--  pkg/tcpip/transport/tcp/accept.go            |  25
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go          |   4
-rw-r--r--  test/syscalls/linux/socket_inet_loopback.cc  | 111
3 files changed, 118 insertions, 22 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 3b574837c..0a2f3291c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -20,6 +20,7 @@ import (
"fmt"
"hash"
"io"
+ "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/rand"
@@ -390,7 +391,7 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state (acceptedChan is nil),
// the new endpoint is closed instead.
-func (e *endpoint) deliverAccepted(n *endpoint) {
+func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
e.mu.Lock()
e.pendingAccepted.Add(1)
e.mu.Unlock()
@@ -405,6 +406,9 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
}
select {
case e.acceptedChan <- n:
+ if !withSynCookie {
+ atomic.AddInt32(&e.synRcvdCount, -1)
+ }
e.acceptMu.Unlock()
e.waiterQueue.Notify(waiter.EventIn)
return
@@ -476,7 +480,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- e.synRcvdCount--
+ atomic.AddInt32(&e.synRcvdCount, -1)
return err
}
@@ -486,18 +490,13 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
ctx.cleanupFailedHandshake(h)
- e.mu.Lock()
- e.synRcvdCount--
- e.mu.Unlock()
+ atomic.AddInt32(&e.synRcvdCount, -1)
return
}
ctx.cleanupCompletedHandshake(h)
- e.mu.Lock()
- e.synRcvdCount--
- e.mu.Unlock()
h.ep.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(h.ep)
+ e.deliverAccepted(h.ep, false /*withSynCookie*/)
}() // S/R-SAFE: synRcvdCount is the barrier.
return nil
@@ -505,17 +504,17 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
func (e *endpoint) incSynRcvdCount() bool {
e.acceptMu.Lock()
- canInc := e.synRcvdCount < cap(e.acceptedChan)
+ canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan)
e.acceptMu.Unlock()
if canInc {
- e.synRcvdCount++
+ atomic.AddInt32(&e.synRcvdCount, 1)
}
return canInc
}
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
- full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
+ full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan)
e.acceptMu.Unlock()
return full
}
@@ -737,7 +736,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// Start the protocol goroutine.
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- go e.deliverAccepted(n)
+ go e.deliverAccepted(n, true /*withSynCookie*/)
return nil
default:
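
For reference, a minimal standalone sketch of the counter pattern this file adopts: synRcvdCount is now an int32 that is only touched through sync/atomic, so the handshake goroutine can decrement it without taking the endpoint mutex, while acceptMu keeps guarding the accepted queue itself. The listener type and field names below are illustrative stand-ins, not the gVisor ones.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// listener mirrors the shape of the change: the SYN-RCVD counter is an int32
// that is only touched through sync/atomic, so it can be decremented from the
// handshake goroutine without holding the endpoint mutex.
type listener struct {
	acceptMu sync.Mutex // still guards the accepted queue itself
	accepted []int      // stand-in for acceptedChan
	backlog  int        // stand-in for cap(acceptedChan)
	synRcvd  int32      // connections currently in SYN-RCVD state
}

// incSynRcvd reserves a SYN-RCVD slot if the backlog still has room,
// mirroring incSynRcvdCount above.
func (l *listener) incSynRcvd() bool {
	l.acceptMu.Lock()
	canInc := int(atomic.LoadInt32(&l.synRcvd)) < l.backlog
	l.acceptMu.Unlock()
	if canInc {
		atomic.AddInt32(&l.synRcvd, 1)
	}
	return canInc
}

// queueIsFull mirrors acceptQueueIsFull: completed connections plus in-flight
// handshakes must not exceed the backlog.
func (l *listener) queueIsFull() bool {
	l.acceptMu.Lock()
	full := len(l.accepted)+int(atomic.LoadInt32(&l.synRcvd)) >= l.backlog
	l.acceptMu.Unlock()
	return full
}

// deliver enqueues a completed connection and only then releases the SYN-RCVD
// slot, matching deliverAccepted(..., false /*withSynCookie*/).
func (l *listener) deliver(conn int, withSynCookie bool) {
	l.acceptMu.Lock()
	l.accepted = append(l.accepted, conn)
	if !withSynCookie {
		atomic.AddInt32(&l.synRcvd, -1)
	}
	l.acceptMu.Unlock()
}

func main() {
	l := &listener{backlog: 1}
	fmt.Println(l.incSynRcvd())  // true: one SYN-RCVD slot is available
	fmt.Println(l.queueIsFull()) // true: that slot consumed the whole backlog
	l.deliver(42, false)
	fmt.Println(l.queueIsFull()) // still true: the accept queue now holds the entry
}
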
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 129f36d11..43d344350 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -532,8 +532,8 @@ type endpoint struct {
segmentQueue segmentQueue `state:"wait"`
// synRcvdCount is the number of connections for this endpoint that are
- // in SYN-RCVD state.
- synRcvdCount int
+ // in SYN-RCVD state; this is only accessed atomically.
+ synRcvdCount int32
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
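
The declaration moves from int to int32 because the sync/atomic helpers used in accept.go (atomic.AddInt32, atomic.LoadInt32) only operate on fixed-width integer types; a plain int has no atomic Add/Load counterparts. A trivial standalone illustration, not gVisor code:

package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	var synRcvdCount int32 // must be a fixed-width type for sync/atomic
	atomic.AddInt32(&synRcvdCount, 1)
	fmt.Println(atomic.LoadInt32(&synRcvdCount)) // prints 1

	// var n int
	// atomic.AddInt32(&n, 1) // would not compile: cannot use &n (*int) as *int32
}
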
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 54b45b075..597b5bcb1 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -490,7 +490,11 @@ void TestListenWhileConnect(const TestParam& param,
TestAddress const& connector = param.connector;
constexpr int kBacklog = 2;
- constexpr int kClients = kBacklog + 1;
+ // Linux completes one more connection than the listen backlog argument.
+ // To ensure that at least one client connection stays in the connecting
+ // state, create two more client connections than the listen backlog.
+ // gVisor differs in this behavior, though; see gvisor.dev/issue/3153.
+ constexpr int kClients = kBacklog + 2;
// Create the listening socket.
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
@@ -527,7 +531,7 @@ void TestListenWhileConnect(const TestParam& param,
for (auto& client : clients) {
constexpr int kTimeout = 10000;
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = client.get(),
.events = POLLIN,
};
@@ -543,6 +547,10 @@ void TestListenWhileConnect(const TestParam& param,
ASSERT_THAT(read(client.get(), &c, sizeof(c)),
AnyOf(SyscallFailsWithErrno(ECONNRESET),
SyscallFailsWithErrno(ECONNREFUSED)));
+ // The last client connection would be in connecting (SYN_SENT) state.
+ if (client.get() == clients[kClients - 1].get()) {
+ ASSERT_EQ(errno, ECONNREFUSED) << strerror(errno);
+ }
}
}
@@ -598,7 +606,7 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
connector.addr_len);
if (ret != 0) {
EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = conn_fd.get(),
.events = POLLOUT,
};
@@ -623,6 +631,95 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
}
}
+// Tests that the stack completes at most the listen backlog number of client
+// connections. It exercises the path in the stack that enqueues completed
+// connections to the accept queue while new incoming SYNs arrive.
+TEST_P(SocketInetLoopbackTest, TCPConnectBacklog_NoRandomSave) {
+ const auto& param = GetParam();
+ const TestAddress& listener = param.listener;
+ const TestAddress& connector = param.connector;
+
+ constexpr int kBacklog = 1;
+ // Create more client connections than the listen backlog.
+ // Linux completes one more connection than the listen backlog argument.
+ // gVisor differs in this behavior, though; see gvisor.dev/issue/3153.
+ int kClients = kBacklog + 2;
+ if (IsRunningOnGvisor()) {
+ kClients--;
+ }
+
+ // Run the following test for a few iterations to exercise the race between
+ // the accept queue filling up and new incoming SYNs arriving.
+ for (int num = 0; num < 10; num++) {
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+
+ std::vector<FileDescriptor> clients;
+ // Issue multiple non-blocking client connects.
+ for (int i = 0; i < kClients; i++) {
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ }
+ clients.push_back(std::move(client));
+ }
+
+ // Now that the client connects have been issued, wait for the accept queue
+ // to fill and ensure that no new client connection completes.
+ for (int i = 0; i < kClients; i++) {
+ pollfd pfd = {
+ .fd = clients[i].get(),
+ .events = POLLOUT,
+ };
+ if (i < kClients - 1) {
+ // Poll for client-side connection completions with a large timeout.
+ // We cannot poll on the listener side without calling accept, as poll
+ // stays level-triggered with a non-zero accept queue length.
+ //
+ // A client-side poll does not guarantee that the completed connection
+ // has been enqueued into the accept queue, but the fact that the
+ // listener ACKed the SYN means that it cannot complete any new incoming
+ // SYNs once it has already ACKed more than backlog SYNs.
+ ASSERT_THAT(poll(&pfd, 1, 10000), SyscallSucceedsWithValue(1))
+ << "num=" << num << " i=" << i << " kClients=" << kClients;
+ ASSERT_EQ(pfd.revents, POLLOUT) << "num=" << num << " i=" << i;
+ } else {
+ // Now that the accept queue is expected to be full, use a smaller poll
+ // timeout to ensure that the last client connection never completes.
+ ASSERT_THAT(poll(&pfd, 1, 1000), SyscallSucceedsWithValue(0))
+ << "num=" << num << " i=" << i;
+ }
+
+ ASSERT_THAT(close(clients[i].release()), SyscallSucceedsWithValue(0))
+ << "num=" << num << " i=" << i;
+ }
+ clients.clear();
+ // We close the listening side and open a new listener. We could instead
+ // drain the accept queue by calling accept() and reuse the listener, but
+ // that is racy as the retransmitted SYNs could get ACKed as we make room
+ // in the accept queue.
+ ASSERT_THAT(close(listen_fd.release()), SyscallSucceedsWithValue(0));
+ }
+}
+
// TCPFinWait2Test creates a pair of connected sockets then closes one end to
// trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local
// IP/port on a new socket and tries to connect. The connect should fail w/
@@ -937,7 +1034,7 @@ void setupTimeWaitClose(const TestAddress* listener,
ASSERT_THAT(shutdown(active_closefd.get(), SHUT_WR), SyscallSucceeds());
{
constexpr int kTimeout = 10000;
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = passive_closefd.get(),
.events = POLLIN,
};
@@ -948,7 +1045,7 @@ void setupTimeWaitClose(const TestAddress* listener,
{
constexpr int kTimeout = 10000;
constexpr int16_t want_events = POLLHUP;
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = active_closefd.get(),
.events = want_events,
};
@@ -1181,7 +1278,7 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
// Wait for accept_fd to process the RST.
constexpr int kTimeout = 10000;
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = accept_fd.get(),
.events = POLLIN,
};
@@ -1705,7 +1802,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) {
SyscallSucceedsWithValue(sizeof(i)));
}
- struct pollfd pollfds[kThreadCount];
+ pollfd pollfds[kThreadCount];
for (int i = 0; i < kThreadCount; i++) {
pollfds[i].fd = listener_fds[i].get();
pollfds[i].events = POLLIN;
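
As a side note, the Linux behaviour the new test leans on can be reproduced outside the test suite. The sketch below is a hypothetical, Linux-only example and not part of gVisor: with listen(fd, 1) on loopback and no accept() calls, only backlog + 1 non-blocking connects complete, and re-calling connect() distinguishes established sockets (EISCONN) from ones still in SYN-SENT (EALREADY).

package main

import (
	"fmt"
	"syscall"
	"time"
)

func main() {
	const backlog = 1
	const clients = backlog + 2

	// Listening socket on 127.0.0.1 with an ephemeral port and backlog 1.
	lfd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
	if err != nil {
		panic(err)
	}
	loopback := [4]byte{127, 0, 0, 1}
	if err := syscall.Bind(lfd, &syscall.SockaddrInet4{Addr: loopback}); err != nil {
		panic(err)
	}
	if err := syscall.Listen(lfd, backlog); err != nil {
		panic(err)
	}
	sa, err := syscall.Getsockname(lfd)
	if err != nil {
		panic(err)
	}
	dst := &syscall.SockaddrInet4{
		Port: sa.(*syscall.SockaddrInet4).Port,
		Addr: loopback,
	}

	// Issue non-blocking connects; each typically returns EINPROGRESS.
	fds := make([]int, clients)
	for i := range fds {
		fd, err := syscall.Socket(syscall.AF_INET,
			syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK, 0)
		if err != nil {
			panic(err)
		}
		_ = syscall.Connect(fd, dst)
		fds[i] = fd
	}

	// Give the loopback handshakes time to finish (or not).
	time.Sleep(time.Second)

	for i, fd := range fds {
		err := syscall.Connect(fd, dst)
		// EISCONN: handshake completed; EALREADY: still in SYN-SENT.
		fmt.Printf("client %d: %v\n", i, err)
	}
}
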