diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 19 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ip_tcp_generic.cc | 29 |
3 files changed, 62 insertions, 15 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index c0b785431..09eff5be1 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1199,21 +1199,24 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { switch e.state { case stateConnected: - // Close for write. - if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 { - if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { - // We're fully closed, if we have unread data we need to abort - // the connection with a RST. - e.rcvListMu.Lock() - rcvBufUsed := e.rcvBufUsed - e.rcvListMu.Unlock() - - if rcvBufUsed > 0 { - e.notifyProtocolGoroutine(notifyReset) - return nil - } + // Close for read. + if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { + // Mark read side as closed. + e.rcvListMu.Lock() + e.rcvClosed = true + rcvBufUsed := e.rcvBufUsed + e.rcvListMu.Unlock() + + // If we're fully closed and we have unread data we need to abort + // the connection with a RST. + if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { + e.notifyProtocolGoroutine(notifyReset) + return nil } + } + // Close for write. + if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 { e.sndBufMu.Lock() if e.sndClosed { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index af50ac8af..c5732ad1c 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -687,6 +687,25 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { }) } +func TestShutdownRead(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) + } +} + func TestFullWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 8a222008e..2f45f01ec 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -190,8 +190,7 @@ TEST_P(TCPSocketPairTest, FINSentOnShutdownWrWithUnreadData) { } // This test will verify that when data is received by a socket, even if it's -// not read SHUT_RD will not cause any packets to be generated and data will -// remain in the buffer and can be read later. +// not read SHUT_RD will not cause any packets to be generated. TEST_P(TCPSocketPairTest, ShutdownRdShouldCauseNoPacketsWithUnreadData) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -213,10 +212,36 @@ TEST_P(TCPSocketPairTest, ShutdownRdShouldCauseNoPacketsWithUnreadData) { constexpr int kPollNoResponseTimeoutMs = 3000; ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollNoResponseTimeoutMs), SyscallSucceedsWithValue(0)); // Timeout. +} + +// This test will verify that a socket which has unread data will still allow +// the data to be read after shutting down the read side, and once there is no +// unread data left, then read will return an EOF. +TEST_P(TCPSocketPairTest, ShutdownRdAllowsReadOfReceivedDataBeforeEOF) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char buf[10] = {}; + ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + // Wait until t_ sees the data on its side but don't read it. + struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0}; + constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data. + ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + + // Now shutdown the read end. + ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RD), SyscallSucceeds()); // Even though we did a SHUT_RD on the read end we can still read the data. ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), SyscallSucceedsWithValue(sizeof(buf))); + + // After reading all of the data, reading the closed read end returns EOF. + ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(0)); } TEST_P(TCPSocketPairTest, ClosedReadNonBlockingSocket) { |