summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMithun Iyer <iyerm@google.com>2021-04-13 00:56:32 -0700
committergVisor bot <gvisor-bot@google.com>2021-04-13 00:58:56 -0700
commit326394b79a62061e3e239ac104c151ca13647439 (patch)
treea67047c75a10c8c8bc81d7261767407ddfffc683
parente5f58e89bbd376469073c749592d0fb0e3b4c6cb (diff)
Fix listener close, client connect race
Fix a race where the ACK completing the handshake can be dropped by a closing listener without an RST being sent to the peer. The listener close would reset the accepted queue, and that causes the connecting endpoint in SYNRCVD state to drop the ACK, thinking the queue is filled up. PiperOrigin-RevId: 368165509
-rw-r--r--pkg/tcpip/transport/tcp/accept.go10
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go2
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc74
4 files changed, 74 insertions, 14 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 7acc7e7b0..63c46b1be 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -511,22 +511,22 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
func (e *endpoint) synRcvdBacklogFull() bool {
e.acceptMu.Lock()
- backlog := e.accepted.cap
+ acceptedCap := e.accepted.cap
e.acceptMu.Unlock()
- // The allocated accepted channel size would always be one greater than the
+ // The capacity of the accepted queue would always be one greater than the
// listen backlog. But, the SYNRCVD connections count is always checked
// against the listen backlog value for Linux parity reason.
// https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
//
// We maintain an equality check here as the synRcvdCount is incremented
// and compared only from a single listener context and the capacity of
- // the accepted channel can only increase by a new listen call.
- return int(atomic.LoadInt32(&e.synRcvdCount)) == backlog-1
+ // the accepted queue can only increase by a new listen call.
+ return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
}
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
- full := e.accepted.endpoints.Len() == e.accepted.cap
+ full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap
e.acceptMu.Unlock()
return full
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 1060a0a90..9afd2bb7f 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -962,7 +962,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
result = mask
case StateListen:
- // Check if there's anything in the accepted channel.
+ // Check if there's anything in the accepted queue.
if (mask & waiter.ReadableEvents) != 0 {
e.acceptMu.Lock()
if e.accepted.endpoints.Len() != 0 {
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index f51b3ad90..590775434 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -67,7 +67,7 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
}
if !e.workerRunning {
- // The endpoint must be in acceptedChan or has been just
+ // The endpoint must be in the accepted queue or has been just
// disconnected and closed.
break
}
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 6b7776186..2026e52b0 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -472,8 +472,68 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
}
}
-void TestListenWhileConnect(const TestParam& param,
- void (*stopListen)(FileDescriptor&)) {
+void TestHangupDuringConnect(const TestParam& param,
+ void (*hangup)(FileDescriptor&)) {
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), 1), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+
+ // Connect asynchronously and immediately hang up the listener.
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ }
+
+ hangup(listen_fd);
+
+ // Wait for the connection to close.
+ struct pollfd pfd = {
+ .fd = client.get(),
+ };
+ constexpr int kTimeout = 10000;
+ int n = poll(&pfd, 1, kTimeout);
+ ASSERT_GE(n, 0) << strerror(errno);
+ ASSERT_EQ(n, 1);
+ ASSERT_EQ(pfd.revents, POLLHUP | POLLERR);
+ ASSERT_EQ(close(client.release()), 0) << strerror(errno);
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenCloseDuringConnect) {
+ TestHangupDuringConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(close(f.release()), SyscallSucceeds());
+ });
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownDuringConnect) {
+ TestHangupDuringConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds());
+ });
+}
+
+void TestListenHangupConnectingRead(const TestParam& param,
+ void (*hangup)(FileDescriptor&)) {
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
@@ -526,7 +586,7 @@ void TestListenWhileConnect(const TestParam& param,
EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
}
- stopListen(listen_fd);
+ hangup(listen_fd);
std::array<std::pair<int, int>, 2> sockets = {
std::make_pair(established_client.get(), ECONNRESET),
@@ -546,14 +606,14 @@ void TestListenWhileConnect(const TestParam& param,
}
}
-TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) {
- TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+TEST_P(SocketInetLoopbackTest, TCPListenCloseConnectingRead) {
+ TestListenHangupConnectingRead(GetParam(), [](FileDescriptor& f) {
ASSERT_THAT(close(f.release()), SyscallSucceeds());
});
}
-TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) {
- TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownConnectingRead) {
+ TestListenHangupConnectingRead(GetParam(), [](FileDescriptor& f) {
ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds());
});
}