diff options
author | Bhasker Hariharan <bhaskerh@google.com> | 2021-06-22 23:38:37 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-06-22 23:41:29 -0700 |
commit | e5fe488b22734e798df760d9646c6b1c5f25c207 (patch) | |
tree | 8e3ae3c5a51cae94596e7a63921ccd4604e3b703 | |
parent | 179ed309f4eaf424c078dba4688eef2731e6649c (diff) |
Wake up Writers when tcp socket is shutdown for writes.
PiperOrigin-RevId: 380967023
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 6 | ||||
-rw-r--r-- | test/syscalls/linux/tcp_socket.cc | 56 |
2 files changed, 62 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a27e2110b..242e6b7f8 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2372,6 +2372,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { e.notifyProtocolGoroutine(notifyTickleWorker) return nil } + // Wake up any readers that maybe waiting for the stream to become + // readable. + e.waiterQueue.Notify(waiter.ReadableEvents) } // Close for write. @@ -2394,6 +2397,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { e.sndQueueInfo.SndClosed = true e.sndQueueInfo.sndQueueMu.Unlock() e.handleClose() + // Wake up any writers that maybe waiting for the stream to become + // writable. + e.waiterQueue.Notify(waiter.WritableEvents) } return nil diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 5bfdecc79..183819faf 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1182,6 +1182,62 @@ TEST_P(SimpleTcpSocketTest, SelfConnectSend) { EXPECT_THAT(shutdown(s.get(), SHUT_WR), SyscallSucceedsWithValue(0)); } +TEST_P(SimpleTcpSocketTest, SelfConnectSendShutdownWrite) { + // Initialize address to the loopback one. + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + const FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds()); + // Get the bound port. + ASSERT_THAT(getsockname(s.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), + SyscallSucceeds()); + + // Write enough data to fill send and receive buffers. + size_t write_size = 24 << 20; // 24 MiB. + std::vector<char> writebuf(write_size); + + ScopedThread t([&s]() { + absl::SleepFor(absl::Milliseconds(250)); + ASSERT_THAT(shutdown(s.get(), SHUT_WR), SyscallSucceeds()); + }); + + // Try to send the whole thing. + int n; + ASSERT_THAT(n = SendFd(s.get(), writebuf.data(), writebuf.size(), 0), + SyscallFailsWithErrno(EPIPE)); +} + +TEST_P(SimpleTcpSocketTest, SelfConnectRecvShutdownRead) { + // Initialize address to the loopback one. + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + const FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds()); + // Get the bound port. + ASSERT_THAT(getsockname(s.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), + SyscallSucceeds()); + + ScopedThread t([&s]() { + absl::SleepFor(absl::Milliseconds(250)); + ASSERT_THAT(shutdown(s.get(), SHUT_RD), SyscallSucceeds()); + }); + + char buf[1]; + EXPECT_THAT(recv(s.get(), buf, 0, 0), SyscallSucceedsWithValue(0)); +} + void NonBlockingConnect(int family, int16_t pollMask) { const FileDescriptor listener = ASSERT_NO_ERRNO_AND_VALUE(Socket(family, SOCK_STREAM, IPPROTO_TCP)); |