summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go29
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go19
2 files changed, 35 insertions, 13 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index c0b785431..09eff5be1 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1199,21 +1199,24 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
switch e.state {
case stateConnected:
- // Close for write.
- if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
- if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
- // We're fully closed, if we have unread data we need to abort
- // the connection with a RST.
- e.rcvListMu.Lock()
- rcvBufUsed := e.rcvBufUsed
- e.rcvListMu.Unlock()
-
- if rcvBufUsed > 0 {
- e.notifyProtocolGoroutine(notifyReset)
- return nil
- }
+ // Close for read.
+ if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
+ // Mark read side as closed.
+ e.rcvListMu.Lock()
+ e.rcvClosed = true
+ rcvBufUsed := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+
+ // If we're fully closed and we have unread data we need to abort
+ // the connection with a RST.
+ if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
+ e.notifyProtocolGoroutine(notifyReset)
+ return nil
}
+ }
+ // Close for write.
+ if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
e.sndBufMu.Lock()
if e.sndClosed {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index af50ac8af..c5732ad1c 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -687,6 +687,25 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
})
}
+func TestShutdownRead(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ }
+
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %v", err)
+ }
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ }
+}
+
func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()