summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/tcp/accept.go10
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go2
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc74
4 files changed, 74 insertions, 14 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 7acc7e7b0..63c46b1be 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -511,22 +511,22 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
func (e *endpoint) synRcvdBacklogFull() bool {
e.acceptMu.Lock()
- backlog := e.accepted.cap
+ acceptedCap := e.accepted.cap
e.acceptMu.Unlock()
- // The allocated accepted channel size would always be one greater than the
+ // The capacity of the accepted queue would always be one greater than the
// listen backlog. But, the SYNRCVD connections count is always checked
// against the listen backlog value for Linux parity reason.
// https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
//
// We maintain an equality check here as the synRcvdCount is incremented
// and compared only from a single listener context and the capacity of
- // the accepted channel can only increase by a new listen call.
- return int(atomic.LoadInt32(&e.synRcvdCount)) == backlog-1
+ // the accepted queue can only increase by a new listen call.
+ return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
}
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
- full := e.accepted.endpoints.Len() == e.accepted.cap
+ full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap
e.acceptMu.Unlock()
return full
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 1060a0a90..9afd2bb7f 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -962,7 +962,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
result = mask
case StateListen:
- // Check if there's anything in the accepted channel.
+ // Check if there's anything in the accepted queue.
if (mask & waiter.ReadableEvents) != 0 {
e.acceptMu.Lock()
if e.accepted.endpoints.Len() != 0 {
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index f51b3ad90..590775434 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -67,7 +67,7 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
}
if !e.workerRunning {
- // The endpoint must be in acceptedChan or has been just
+ // The endpoint must be in the accepted queue or has been just
// disconnected and closed.
break
}
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 6b7776186..2026e52b0 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -472,8 +472,68 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
}
}
-void TestListenWhileConnect(const TestParam& param,
- void (*stopListen)(FileDescriptor&)) {
+void TestHangupDuringConnect(const TestParam& param,
+ void (*hangup)(FileDescriptor&)) {
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), 1), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+
+ // Connect asynchronously and immediately hang up the listener.
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ }
+
+ hangup(listen_fd);
+
+ // Wait for the connection to close.
+ struct pollfd pfd = {
+ .fd = client.get(),
+ };
+ constexpr int kTimeout = 10000;
+ int n = poll(&pfd, 1, kTimeout);
+ ASSERT_GE(n, 0) << strerror(errno);
+ ASSERT_EQ(n, 1);
+ ASSERT_EQ(pfd.revents, POLLHUP | POLLERR);
+ ASSERT_EQ(close(client.release()), 0) << strerror(errno);
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenCloseDuringConnect) {
+ TestHangupDuringConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(close(f.release()), SyscallSucceeds());
+ });
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownDuringConnect) {
+ TestHangupDuringConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds());
+ });
+}
+
+void TestListenHangupConnectingRead(const TestParam& param,
+ void (*hangup)(FileDescriptor&)) {
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
@@ -526,7 +586,7 @@ void TestListenWhileConnect(const TestParam& param,
EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
}
- stopListen(listen_fd);
+ hangup(listen_fd);
std::array<std::pair<int, int>, 2> sockets = {
std::make_pair(established_client.get(), ECONNRESET),
@@ -546,14 +606,14 @@ void TestListenWhileConnect(const TestParam& param,
}
}
-TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) {
- TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+TEST_P(SocketInetLoopbackTest, TCPListenCloseConnectingRead) {
+ TestListenHangupConnectingRead(GetParam(), [](FileDescriptor& f) {
ASSERT_THAT(close(f.release()), SyscallSucceeds());
});
}
-TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) {
- TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownConnectingRead) {
+ TestListenHangupConnectingRead(GetParam(), [](FileDescriptor& f) {
ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds());
});
}