diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 75 | ||||
-rw-r--r-- | test/syscalls/linux/tcp_socket.cc | 63 |
5 files changed, 158 insertions, 5 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5d8e18484..80cd07218 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -30,6 +30,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// InitialRTO is the initial retransmission timeout. +// https://github.com/torvalds/linux/blob/7c636d4d20f/include/net/tcp.h#L142 +const InitialRTO = time.Second + // maxSegmentsPerWake is the maximum number of segments to process in the main // protocol goroutine per wake-up. Yielding [after this number of segments are // processed] allows other events to be processed as well (e.g., timeouts, @@ -532,7 +536,7 @@ func (h *handshake) complete() tcpip.Error { defer s.Done() // Initialize the resend timer. - timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert) + timer, err := newBackoffTimer(h.ep.stack.Clock(), InitialRTO, MaxRTO, resendWaker.Assert) if err != nil { return err } @@ -578,6 +582,9 @@ func (h *handshake) complete() tcpip.Error { if (n¬ifyClose)|(n¬ifyAbort) != 0 { return &tcpip.ErrAborted{} } + if n¬ifyShutdown != 0 { + return &tcpip.ErrConnectionReset{} + } if n¬ifyDrain != 0 { for !h.ep.segmentQueue.empty() { s := h.ep.segmentQueue.dequeue() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 83c51e855..a3002abf3 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -187,6 +187,8 @@ const ( // say TIME_WAIT. notifyTickleWorker notifyError + // notifyShutdown means that a connecting socket was shutdown. + notifyShutdown ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -2384,6 +2386,18 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.LockUser() defer e.UnlockUser() + + if e.EndpointState().connecting() { + // When calling shutdown(2) on a connecting socket, the endpoint must + // enter the error state. But this logic cannot belong to the shutdownLocked + // method because that method is called during a close(2) (and closing a + // connecting socket is not an error). + e.resetConnectionLocked(&tcpip.ErrConnectionReset{}) + e.notifyProtocolGoroutine(notifyShutdown) + e.waiterQueue.Notify(waiter.WritableEvents | waiter.EventHUp | waiter.EventErr) + return nil + } + return e.shutdownLocked(flags) } diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go index 2e6ea06f5..2d5fdda19 100644 --- a/pkg/tcpip/transport/tcp/segment_test.go +++ b/pkg/tcpip/transport/tcp/segment_test.go @@ -34,7 +34,7 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW DataSize: seg.data.Size(), SegMemSize: seg.segMemSize(), } - if diff := cmp.Diff(got, want); diff != "" { + if diff := cmp.Diff(want, got); diff != "" { t.Errorf("%s differs (-want +got):\n%s", name, diff) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 58817371e..6f1ee3816 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1656,6 +1656,71 @@ func TestConnectBindToDevice(t *testing.T) { } } +func TestShutdownConnectingSocket(t *testing.T) { + for _, test := range []struct { + name string + shutdownMode tcpip.ShutdownFlags + }{ + {"ShutdownRead", tcpip.ShutdownRead}, + {"ShutdownWrite", tcpip.ShutdownWrite}, + {"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create an endpoint, don't handshake because we want to interfere with + // the handshake process. + c.Create(-1) + + waitEntry, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventHUp) + defer c.WQ.EventUnregister(&waitEntry) + + // Start connection attempt. + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, c.EP.Connect(addr)); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) + } + + // Check the SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + + if err := c.EP.Shutdown(test.shutdownMode); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } + + // The endpoint internal state is updated immediately. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + + select { + case <-ch: + default: + t.Fatal("endpoint was not notified") + } + + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrConnectionReset{}) + + // If the endpoint is not properly shutdown, it'll re-attempt to connect + // by sending another ACK packet. + c.CheckNoPacketTimeout("got an unexpected packet", tcp.InitialRTO+(500*time.Millisecond)) + }) + } +} + func TestSynSent(t *testing.T) { for _, test := range []struct { name string @@ -1679,7 +1744,7 @@ func TestSynSent(t *testing.T) { addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} err := c.EP.Connect(addr) - if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" { + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) } @@ -1995,7 +2060,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { ) // Cause a FIN to be generated. - c.EP.Shutdown(tcpip.ShutdownWrite) + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } // Make sure we get the FIN but DON't ACK IT. checker.IPv4(t, c.GetPacket(), @@ -2011,7 +2078,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // Cause a RST to be generated by closing the read end now since we have // unread data. - c.EP.Shutdown(tcpip.ShutdownRead) + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } // Make sure we get the RST checker.IPv4(t, c.GetPacket(), diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 607182ffd..bbcb7e4fd 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -2067,6 +2067,69 @@ TEST_P(SimpleTcpSocketTest, GetSocketAcceptConnWithShutdown) { EXPECT_EQ(got, 0); } +void ShutdownConnectingSocket(int domain, int shutdown_mode) { + FileDescriptor bound_s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(domain, SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage bound_addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(domain)); + socklen_t bound_addrlen = sizeof(bound_addr); + + ASSERT_THAT(bind(bound_s.get(), AsSockAddr(&bound_addr), bound_addrlen), + SyscallSucceeds()); + + // Start listening. Use a zero backlog to only allow one connection in the + // accept queue. + ASSERT_THAT(listen(bound_s.get(), 0), SyscallSucceeds()); + + // Get the addresses the socket is bound to because the port is chosen by the + // stack. + ASSERT_THAT( + getsockname(bound_s.get(), AsSockAddr(&bound_addr), &bound_addrlen), + SyscallSucceeds()); + + // Establish a connection. But do not accept it. That way, subsequent + // connections will not get a SYN-ACK because the queue is full. + FileDescriptor connected_s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(domain, SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(connected_s.get(), + reinterpret_cast<const struct sockaddr*>(&bound_addr), + bound_addrlen), + SyscallSucceeds()); + + FileDescriptor connecting_s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(domain, SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + ASSERT_THAT(connect(connecting_s.get(), + reinterpret_cast<const struct sockaddr*>(&bound_addr), + bound_addrlen), + SyscallFailsWithErrno(EINPROGRESS)); + + // Now the test: when a connecting socket is shutdown, the socket should enter + // an error state. + EXPECT_THAT(shutdown(connecting_s.get(), shutdown_mode), SyscallSucceeds()); + + // We don't need to specify any events to get POLLHUP or POLLERR because these + // are always tracked. + struct pollfd poll_fd = { + .fd = connecting_s.get(), + }; + + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 0), SyscallSucceedsWithValue(1)); + EXPECT_EQ(poll_fd.revents, POLLHUP | POLLERR); +} + +TEST_P(SimpleTcpSocketTest, ShutdownReadConnectingSocket) { + ShutdownConnectingSocket(GetParam(), SHUT_RD); +} + +TEST_P(SimpleTcpSocketTest, ShutdownWriteConnectingSocket) { + ShutdownConnectingSocket(GetParam(), SHUT_WR); +} + +TEST_P(SimpleTcpSocketTest, ShutdownReadWriteConnectingSocket) { + ShutdownConnectingSocket(GetParam(), SHUT_RDWR); +} + // Tests that connecting to an unspecified address results in ECONNREFUSED. TEST_P(SimpleTcpSocketTest, ConnectUnspecifiedAddress) { sockaddr_storage addr; |