summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorBhasker Hariharan <bhaskerh@google.com>2021-06-22 23:38:37 -0700
committergVisor bot <gvisor-bot@google.com>2021-06-22 23:41:29 -0700
commite5fe488b22734e798df760d9646c6b1c5f25c207 (patch)
tree8e3ae3c5a51cae94596e7a63921ccd4604e3b703
parent179ed309f4eaf424c078dba4688eef2731e6649c (diff)
Wake up Writers when tcp socket is shutdown for writes.
PiperOrigin-RevId: 380967023
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go6
-rw-r--r--test/syscalls/linux/tcp_socket.cc56
2 files changed, 62 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a27e2110b..242e6b7f8 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2372,6 +2372,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
e.notifyProtocolGoroutine(notifyTickleWorker)
return nil
}
+ // Wake up any readers that maybe waiting for the stream to become
+ // readable.
+ e.waiterQueue.Notify(waiter.ReadableEvents)
}
// Close for write.
@@ -2394,6 +2397,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
e.sndQueueInfo.SndClosed = true
e.sndQueueInfo.sndQueueMu.Unlock()
e.handleClose()
+ // Wake up any writers that maybe waiting for the stream to become
+ // writable.
+ e.waiterQueue.Notify(waiter.WritableEvents)
}
return nil
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 5bfdecc79..183819faf 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -1182,6 +1182,62 @@ TEST_P(SimpleTcpSocketTest, SelfConnectSend) {
EXPECT_THAT(shutdown(s.get(), SHUT_WR), SyscallSucceedsWithValue(0));
}
+TEST_P(SimpleTcpSocketTest, SelfConnectSendShutdownWrite) {
+ // Initialize address to the loopback one.
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ const FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(bind(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds());
+ // Get the bound port.
+ ASSERT_THAT(getsockname(s.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Write enough data to fill send and receive buffers.
+ size_t write_size = 24 << 20; // 24 MiB.
+ std::vector<char> writebuf(write_size);
+
+ ScopedThread t([&s]() {
+ absl::SleepFor(absl::Milliseconds(250));
+ ASSERT_THAT(shutdown(s.get(), SHUT_WR), SyscallSucceeds());
+ });
+
+ // Try to send the whole thing.
+ int n;
+ ASSERT_THAT(n = SendFd(s.get(), writebuf.data(), writebuf.size(), 0),
+ SyscallFailsWithErrno(EPIPE));
+}
+
+TEST_P(SimpleTcpSocketTest, SelfConnectRecvShutdownRead) {
+ // Initialize address to the loopback one.
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ const FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(bind(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds());
+ // Get the bound port.
+ ASSERT_THAT(getsockname(s.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen),
+ SyscallSucceeds());
+
+ ScopedThread t([&s]() {
+ absl::SleepFor(absl::Milliseconds(250));
+ ASSERT_THAT(shutdown(s.get(), SHUT_RD), SyscallSucceeds());
+ });
+
+ char buf[1];
+ EXPECT_THAT(recv(s.get(), buf, 0, 0), SyscallSucceedsWithValue(0));
+}
+
void NonBlockingConnect(int family, int16_t pollMask) {
const FileDescriptor listener =
ASSERT_NO_ERRNO_AND_VALUE(Socket(family, SOCK_STREAM, IPPROTO_TCP));