summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMithun Iyer <iyerm@google.com>2021-03-16 15:06:26 -0700
committergVisor bot <gvisor-bot@google.com>2021-03-16 15:08:09 -0700
commit5eede4e7563e245a685d6529dffddbf9c3a53f50 (patch)
treee18e4bb3a03dd08a70e7176a24a58969edc27129
parent607a1e481c276c8ab0c3e194ed04b38bc07b71b6 (diff)
Fix a race with synRcvdCount and accept
There is a race in handling new incoming connections on a listening endpoint that causes the endpoint to reply to more incoming SYNs than what is permitted by the listen backlog. The race occurs when there is a successful passive connection handshake and the synRcvdCount counter is decremented, followed by the endpoint delivered to the accept queue. In the window of time between synRcvdCount decrementing and the endpoint being enqueued for accept, new incoming SYNs can be handled without honoring the listen backlog value, as the backlog could be perceived not full. Fixes #5637 PiperOrigin-RevId: 363279372
-rw-r--r--pkg/tcpip/transport/tcp/accept.go25
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go4
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc111
3 files changed, 118 insertions, 22 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 3b574837c..0a2f3291c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -20,6 +20,7 @@ import (
"fmt"
"hash"
"io"
+ "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/rand"
@@ -390,7 +391,7 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state (acceptedChan is nil),
// the new endpoint is closed instead.
-func (e *endpoint) deliverAccepted(n *endpoint) {
+func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
e.mu.Lock()
e.pendingAccepted.Add(1)
e.mu.Unlock()
@@ -405,6 +406,9 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
}
select {
case e.acceptedChan <- n:
+ if !withSynCookie {
+ atomic.AddInt32(&e.synRcvdCount, -1)
+ }
e.acceptMu.Unlock()
e.waiterQueue.Notify(waiter.EventIn)
return
@@ -476,7 +480,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- e.synRcvdCount--
+ atomic.AddInt32(&e.synRcvdCount, -1)
return err
}
@@ -486,18 +490,13 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
ctx.cleanupFailedHandshake(h)
- e.mu.Lock()
- e.synRcvdCount--
- e.mu.Unlock()
+ atomic.AddInt32(&e.synRcvdCount, -1)
return
}
ctx.cleanupCompletedHandshake(h)
- e.mu.Lock()
- e.synRcvdCount--
- e.mu.Unlock()
h.ep.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(h.ep)
+ e.deliverAccepted(h.ep, false /*withSynCookie*/)
}() // S/R-SAFE: synRcvdCount is the barrier.
return nil
@@ -505,17 +504,17 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
func (e *endpoint) incSynRcvdCount() bool {
e.acceptMu.Lock()
- canInc := e.synRcvdCount < cap(e.acceptedChan)
+ canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan)
e.acceptMu.Unlock()
if canInc {
- e.synRcvdCount++
+ atomic.AddInt32(&e.synRcvdCount, 1)
}
return canInc
}
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
- full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
+ full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan)
e.acceptMu.Unlock()
return full
}
@@ -737,7 +736,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// Start the protocol goroutine.
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- go e.deliverAccepted(n)
+ go e.deliverAccepted(n, true /*withSynCookie*/)
return nil
default:
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 129f36d11..43d344350 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -532,8 +532,8 @@ type endpoint struct {
segmentQueue segmentQueue `state:"wait"`
// synRcvdCount is the number of connections for this endpoint that are
- // in SYN-RCVD state.
- synRcvdCount int
+ // in SYN-RCVD state; this is only accessed atomically.
+ synRcvdCount int32
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 54b45b075..597b5bcb1 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -490,7 +490,11 @@ void TestListenWhileConnect(const TestParam& param,
TestAddress const& connector = param.connector;
constexpr int kBacklog = 2;
- constexpr int kClients = kBacklog + 1;
+ // Linux completes one more connection than the listen backlog argument.
+ // To ensure that there is at least one client connection that stays in
+ // connecting state, keep 2 more client connections than the listen backlog.
+ // gVisor differs in this behavior though, gvisor.dev/issue/3153.
+ constexpr int kClients = kBacklog + 2;
// Create the listening socket.
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
@@ -527,7 +531,7 @@ void TestListenWhileConnect(const TestParam& param,
for (auto& client : clients) {
constexpr int kTimeout = 10000;
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = client.get(),
.events = POLLIN,
};
@@ -543,6 +547,10 @@ void TestListenWhileConnect(const TestParam& param,
ASSERT_THAT(read(client.get(), &c, sizeof(c)),
AnyOf(SyscallFailsWithErrno(ECONNRESET),
SyscallFailsWithErrno(ECONNREFUSED)));
+ // The last client connection would be in connecting (SYN_SENT) state.
+ if (client.get() == clients[kClients - 1].get()) {
+ ASSERT_EQ(errno, ECONNREFUSED) << strerror(errno);
+ }
}
}
@@ -598,7 +606,7 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
connector.addr_len);
if (ret != 0) {
EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = conn_fd.get(),
.events = POLLOUT,
};
@@ -623,6 +631,95 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
}
}
+// Test if the stack completes atmost listen backlog number of client
+// connections. It exercises the path of the stack that enqueues completed
+// connections to accept queue vs new incoming SYNs.
+TEST_P(SocketInetLoopbackTest, TCPConnectBacklog_NoRandomSave) {
+ const auto& param = GetParam();
+ const TestAddress& listener = param.listener;
+ const TestAddress& connector = param.connector;
+
+ constexpr int kBacklog = 1;
+ // Keep the number of client connections more than the listen backlog.
+ // Linux completes one more connection than the listen backlog argument.
+ // gVisor differs in this behavior though, gvisor.dev/issue/3153.
+ int kClients = kBacklog + 2;
+ if (IsRunningOnGvisor()) {
+ kClients--;
+ }
+
+ // Run the following test for few iterations to test race between accept queue
+ // getting filled with incoming SYNs.
+ for (int num = 0; num < 10; num++) {
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+
+ std::vector<FileDescriptor> clients;
+ // Issue multiple non-blocking client connects.
+ for (int i = 0; i < kClients; i++) {
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ }
+ clients.push_back(std::move(client));
+ }
+
+ // Now that client connects are issued, wait for the accept queue to get
+ // filled and ensure no new client connection is completed.
+ for (int i = 0; i < kClients; i++) {
+ pollfd pfd = {
+ .fd = clients[i].get(),
+ .events = POLLOUT,
+ };
+ if (i < kClients - 1) {
+ // Poll for client side connection completions with a large timeout.
+ // We cannot poll on the listener side without calling accept as poll
+ // stays level triggered with non-zero accept queue length.
+ //
+ // Client side poll would not guarantee that the completed connection
+ // has been enqueued in to the acccept queue, but the fact that the
+ // listener ACKd the SYN, means that it cannot complete any new incoming
+ // SYNs when it has already ACKd for > backlog number of SYNs.
+ ASSERT_THAT(poll(&pfd, 1, 10000), SyscallSucceedsWithValue(1))
+ << "num=" << num << " i=" << i << " kClients=" << kClients;
+ ASSERT_EQ(pfd.revents, POLLOUT) << "num=" << num << " i=" << i;
+ } else {
+ // Now that we expect accept queue filled up, ensure that the last
+ // client connection never completes with a smaller poll timeout.
+ ASSERT_THAT(poll(&pfd, 1, 1000), SyscallSucceedsWithValue(0))
+ << "num=" << num << " i=" << i;
+ }
+
+ ASSERT_THAT(close(clients[i].release()), SyscallSucceedsWithValue(0))
+ << "num=" << num << " i=" << i;
+ }
+ clients.clear();
+ // We close the listening side and open a new listener. We could instead
+ // drain the accept queue by calling accept() and reuse the listener, but
+ // that is racy as the retransmitted SYNs could get ACKd as we make room in
+ // the accept queue.
+ ASSERT_THAT(close(listen_fd.release()), SyscallSucceedsWithValue(0));
+ }
+}
+
// TCPFinWait2Test creates a pair of connected sockets then closes one end to
// trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local
// IP/port on a new socket and tries to connect. The connect should fail w/
@@ -937,7 +1034,7 @@ void setupTimeWaitClose(const TestAddress* listener,
ASSERT_THAT(shutdown(active_closefd.get(), SHUT_WR), SyscallSucceeds());
{
constexpr int kTimeout = 10000;
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = passive_closefd.get(),
.events = POLLIN,
};
@@ -948,7 +1045,7 @@ void setupTimeWaitClose(const TestAddress* listener,
{
constexpr int kTimeout = 10000;
constexpr int16_t want_events = POLLHUP;
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = active_closefd.get(),
.events = want_events,
};
@@ -1181,7 +1278,7 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
// Wait for accept_fd to process the RST.
constexpr int kTimeout = 10000;
- struct pollfd pfd = {
+ pollfd pfd = {
.fd = accept_fd.get(),
.events = POLLIN,
};
@@ -1705,7 +1802,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) {
SyscallSucceedsWithValue(sizeof(i)));
}
- struct pollfd pollfds[kThreadCount];
+ pollfd pollfds[kThreadCount];
for (int i = 0; i < kThreadCount; i++) {
pollfds[i].fd = listener_fds[i].get();
pollfds[i].events = POLLIN;