From 0ed53e5e92a065f63864ecfd0dd6d93c12daaef5 Mon Sep 17 00:00:00 2001
From: Arthur Sfez <asfez@google.com>
Date: Tue, 21 Sep 2021 17:27:27 -0700
Subject: Handle Shutdown on connecting tcp socket

Fixes #6495

PiperOrigin-RevId: 398121921
---
 test/syscalls/linux/tcp_socket.cc | 63 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 63 insertions(+)

(limited to 'test/syscalls/linux')

diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 607182ffd..bbcb7e4fd 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -2067,6 +2067,69 @@ TEST_P(SimpleTcpSocketTest, GetSocketAcceptConnWithShutdown) {
   EXPECT_EQ(got, 0);
 }
 
+void ShutdownConnectingSocket(int domain, int shutdown_mode) {
+  FileDescriptor bound_s =
+      ASSERT_NO_ERRNO_AND_VALUE(Socket(domain, SOCK_STREAM, IPPROTO_TCP));
+
+  sockaddr_storage bound_addr =
+      ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(domain));
+  socklen_t bound_addrlen = sizeof(bound_addr);
+
+  ASSERT_THAT(bind(bound_s.get(), AsSockAddr(&bound_addr), bound_addrlen),
+              SyscallSucceeds());
+
+  // Start listening. Use a zero backlog to only allow one connection in the
+  // accept queue.
+  ASSERT_THAT(listen(bound_s.get(), 0), SyscallSucceeds());
+
+  // Get the addresses the socket is bound to because the port is chosen by the
+  // stack.
+  ASSERT_THAT(
+      getsockname(bound_s.get(), AsSockAddr(&bound_addr), &bound_addrlen),
+      SyscallSucceeds());
+
+  // Establish a connection. But do not accept it. That way, subsequent
+  // connections will not get a SYN-ACK because the queue is full.
+  FileDescriptor connected_s =
+      ASSERT_NO_ERRNO_AND_VALUE(Socket(domain, SOCK_STREAM, IPPROTO_TCP));
+  ASSERT_THAT(connect(connected_s.get(),
+                      reinterpret_cast<const struct sockaddr*>(&bound_addr),
+                      bound_addrlen),
+              SyscallSucceeds());
+
+  FileDescriptor connecting_s = ASSERT_NO_ERRNO_AND_VALUE(
+      Socket(domain, SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+  ASSERT_THAT(connect(connecting_s.get(),
+                      reinterpret_cast<const struct sockaddr*>(&bound_addr),
+                      bound_addrlen),
+              SyscallFailsWithErrno(EINPROGRESS));
+
+  // Now the test: when a connecting socket is shutdown, the socket should enter
+  // an error state.
+  EXPECT_THAT(shutdown(connecting_s.get(), shutdown_mode), SyscallSucceeds());
+
+  // We don't need to specify any events to get POLLHUP or POLLERR because these
+  // are always tracked.
+  struct pollfd poll_fd = {
+      .fd = connecting_s.get(),
+  };
+
+  EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 0), SyscallSucceedsWithValue(1));
+  EXPECT_EQ(poll_fd.revents, POLLHUP | POLLERR);
+}
+
+TEST_P(SimpleTcpSocketTest, ShutdownReadConnectingSocket) {
+  ShutdownConnectingSocket(GetParam(), SHUT_RD);
+}
+
+TEST_P(SimpleTcpSocketTest, ShutdownWriteConnectingSocket) {
+  ShutdownConnectingSocket(GetParam(), SHUT_WR);
+}
+
+TEST_P(SimpleTcpSocketTest, ShutdownReadWriteConnectingSocket) {
+  ShutdownConnectingSocket(GetParam(), SHUT_RDWR);
+}
+
 // Tests that connecting to an unspecified address results in ECONNREFUSED.
 TEST_P(SimpleTcpSocketTest, ConnectUnspecifiedAddress) {
   sockaddr_storage addr;
-- 
cgit v1.2.3