diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 2 | ||||
-rw-r--r-- | test/syscalls/linux/socket_inet_loopback.cc | 74 |
4 files changed, 74 insertions, 14 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 7acc7e7b0..63c46b1be 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -511,22 +511,22 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header func (e *endpoint) synRcvdBacklogFull() bool { e.acceptMu.Lock() - backlog := e.accepted.cap + acceptedCap := e.accepted.cap e.acceptMu.Unlock() - // The allocated accepted channel size would always be one greater than the + // The capacity of the accepted queue would always be one greater than the // listen backlog. But, the SYNRCVD connections count is always checked // against the listen backlog value for Linux parity reason. // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 // // We maintain an equality check here as the synRcvdCount is incremented // and compared only from a single listener context and the capacity of - // the accepted channel can only increase by a new listen call. - return int(atomic.LoadInt32(&e.synRcvdCount)) == backlog-1 + // the accepted queue can only increase by a new listen call. + return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1 } func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() - full := e.accepted.endpoints.Len() == e.accepted.cap + full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap e.acceptMu.Unlock() return full } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 1060a0a90..9afd2bb7f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -962,7 +962,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { result = mask case StateListen: - // Check if there's anything in the accepted channel. + // Check if there's anything in the accepted queue. if (mask & waiter.ReadableEvents) != 0 { e.acceptMu.Lock() if e.accepted.endpoints.Len() != 0 { diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index f51b3ad90..590775434 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -67,7 +67,7 @@ func (e *endpoint) beforeSave() { e.mu.Lock() } if !e.workerRunning { - // The endpoint must be in acceptedChan or has been just + // The endpoint must be in the accepted queue or has been just // disconnected and closed. break } diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 6b7776186..2026e52b0 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -472,8 +472,68 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { } } -void TestListenWhileConnect(const TestParam& param, - void (*stopListen)(FileDescriptor&)) { +void TestHangupDuringConnect(const TestParam& param, + void (*hangup)(FileDescriptor&)) { + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), 1), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + + // Connect asynchronously and immediately hang up the listener. + FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + } + + hangup(listen_fd); + + // Wait for the connection to close. + struct pollfd pfd = { + .fd = client.get(), + }; + constexpr int kTimeout = 10000; + int n = poll(&pfd, 1, kTimeout); + ASSERT_GE(n, 0) << strerror(errno); + ASSERT_EQ(n, 1); + ASSERT_EQ(pfd.revents, POLLHUP | POLLERR); + ASSERT_EQ(close(client.release()), 0) << strerror(errno); +} + +TEST_P(SocketInetLoopbackTest, TCPListenCloseDuringConnect) { + TestHangupDuringConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(close(f.release()), SyscallSucceeds()); + }); +} + +TEST_P(SocketInetLoopbackTest, TCPListenShutdownDuringConnect) { + TestHangupDuringConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds()); + }); +} + +void TestListenHangupConnectingRead(const TestParam& param, + void (*hangup)(FileDescriptor&)) { TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -526,7 +586,7 @@ void TestListenWhileConnect(const TestParam& param, EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); } - stopListen(listen_fd); + hangup(listen_fd); std::array<std::pair<int, int>, 2> sockets = { std::make_pair(established_client.get(), ECONNRESET), @@ -546,14 +606,14 @@ void TestListenWhileConnect(const TestParam& param, } } -TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) { - TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { +TEST_P(SocketInetLoopbackTest, TCPListenCloseConnectingRead) { + TestListenHangupConnectingRead(GetParam(), [](FileDescriptor& f) { ASSERT_THAT(close(f.release()), SyscallSucceeds()); }); } -TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) { - TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { +TEST_P(SocketInetLoopbackTest, TCPListenShutdownConnectingRead) { + TestListenHangupConnectingRead(GetParam(), [](FileDescriptor& f) { ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds()); }); } |