summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go29
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go19
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc29
3 files changed, 62 insertions, 15 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index c0b785431..09eff5be1 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1199,21 +1199,24 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
switch e.state {
case stateConnected:
- // Close for write.
- if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
- if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
- // We're fully closed, if we have unread data we need to abort
- // the connection with a RST.
- e.rcvListMu.Lock()
- rcvBufUsed := e.rcvBufUsed
- e.rcvListMu.Unlock()
-
- if rcvBufUsed > 0 {
- e.notifyProtocolGoroutine(notifyReset)
- return nil
- }
+ // Close for read.
+ if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
+ // Mark read side as closed.
+ e.rcvListMu.Lock()
+ e.rcvClosed = true
+ rcvBufUsed := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+
+ // If we're fully closed and we have unread data we need to abort
+ // the connection with a RST.
+ if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
+ e.notifyProtocolGoroutine(notifyReset)
+ return nil
}
+ }
+ // Close for write.
+ if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
e.sndBufMu.Lock()
if e.sndClosed {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index af50ac8af..c5732ad1c 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -687,6 +687,25 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
})
}
+func TestShutdownRead(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ }
+
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %v", err)
+ }
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ }
+}
+
func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index 8a222008e..2f45f01ec 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -190,8 +190,7 @@ TEST_P(TCPSocketPairTest, FINSentOnShutdownWrWithUnreadData) {
}
// This test will verify that when data is received by a socket, even if it's
-// not read SHUT_RD will not cause any packets to be generated and data will
-// remain in the buffer and can be read later.
+// not read SHUT_RD will not cause any packets to be generated.
TEST_P(TCPSocketPairTest, ShutdownRdShouldCauseNoPacketsWithUnreadData) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -213,10 +212,36 @@ TEST_P(TCPSocketPairTest, ShutdownRdShouldCauseNoPacketsWithUnreadData) {
constexpr int kPollNoResponseTimeoutMs = 3000;
ASSERT_THAT(RetryEINTR(poll)(&poll_fd2, 1, kPollNoResponseTimeoutMs),
SyscallSucceedsWithValue(0)); // Timeout.
+}
+
+// This test will verify that a socket which has unread data will still allow
+// the data to be read after shutting down the read side, and once there is no
+// unread data left, then read will return an EOF.
+TEST_P(TCPSocketPairTest, ShutdownRdAllowsReadOfReceivedDataBeforeEOF) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char buf[10] = {};
+ ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Wait until t_ sees the data on its side but don't read it.
+ struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0};
+ constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data.
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+
+ // Now shutdown the read end.
+ ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RD), SyscallSucceeds());
// Even though we did a SHUT_RD on the read end we can still read the data.
ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)),
SyscallSucceedsWithValue(sizeof(buf)));
+
+ // After reading all of the data, reading the closed read end returns EOF.
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+ ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(0));
}
TEST_P(TCPSocketPairTest, ClosedReadNonBlockingSocket) {