diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 75 |
4 files changed, 95 insertions, 5 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5d8e18484..80cd07218 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -30,6 +30,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// InitialRTO is the initial retransmission timeout. +// https://github.com/torvalds/linux/blob/7c636d4d20f/include/net/tcp.h#L142 +const InitialRTO = time.Second + // maxSegmentsPerWake is the maximum number of segments to process in the main // protocol goroutine per wake-up. Yielding [after this number of segments are // processed] allows other events to be processed as well (e.g., timeouts, @@ -532,7 +536,7 @@ func (h *handshake) complete() tcpip.Error { defer s.Done() // Initialize the resend timer. - timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert) + timer, err := newBackoffTimer(h.ep.stack.Clock(), InitialRTO, MaxRTO, resendWaker.Assert) if err != nil { return err } @@ -578,6 +582,9 @@ func (h *handshake) complete() tcpip.Error { if (n¬ifyClose)|(n¬ifyAbort) != 0 { return &tcpip.ErrAborted{} } + if n¬ifyShutdown != 0 { + return &tcpip.ErrConnectionReset{} + } if n¬ifyDrain != 0 { for !h.ep.segmentQueue.empty() { s := h.ep.segmentQueue.dequeue() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 83c51e855..a3002abf3 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -187,6 +187,8 @@ const ( // say TIME_WAIT. notifyTickleWorker notifyError + // notifyShutdown means that a connecting socket was shutdown. + notifyShutdown ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -2384,6 +2386,18 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.LockUser() defer e.UnlockUser() + + if e.EndpointState().connecting() { + // When calling shutdown(2) on a connecting socket, the endpoint must + // enter the error state. But this logic cannot belong to the shutdownLocked + // method because that method is called during a close(2) (and closing a + // connecting socket is not an error). + e.resetConnectionLocked(&tcpip.ErrConnectionReset{}) + e.notifyProtocolGoroutine(notifyShutdown) + e.waiterQueue.Notify(waiter.WritableEvents | waiter.EventHUp | waiter.EventErr) + return nil + } + return e.shutdownLocked(flags) } diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go index 2e6ea06f5..2d5fdda19 100644 --- a/pkg/tcpip/transport/tcp/segment_test.go +++ b/pkg/tcpip/transport/tcp/segment_test.go @@ -34,7 +34,7 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW DataSize: seg.data.Size(), SegMemSize: seg.segMemSize(), } - if diff := cmp.Diff(got, want); diff != "" { + if diff := cmp.Diff(want, got); diff != "" { t.Errorf("%s differs (-want +got):\n%s", name, diff) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 58817371e..6f1ee3816 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1656,6 +1656,71 @@ func TestConnectBindToDevice(t *testing.T) { } } +func TestShutdownConnectingSocket(t *testing.T) { + for _, test := range []struct { + name string + shutdownMode tcpip.ShutdownFlags + }{ + {"ShutdownRead", tcpip.ShutdownRead}, + {"ShutdownWrite", tcpip.ShutdownWrite}, + {"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create an endpoint, don't handshake because we want to interfere with + // the handshake process. + c.Create(-1) + + waitEntry, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventHUp) + defer c.WQ.EventUnregister(&waitEntry) + + // Start connection attempt. + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, c.EP.Connect(addr)); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) + } + + // Check the SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + + if err := c.EP.Shutdown(test.shutdownMode); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } + + // The endpoint internal state is updated immediately. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + + select { + case <-ch: + default: + t.Fatal("endpoint was not notified") + } + + ept := endpointTester{c.EP} + ept.CheckReadError(t, &tcpip.ErrConnectionReset{}) + + // If the endpoint is not properly shutdown, it'll re-attempt to connect + // by sending another ACK packet. + c.CheckNoPacketTimeout("got an unexpected packet", tcp.InitialRTO+(500*time.Millisecond)) + }) + } +} + func TestSynSent(t *testing.T) { for _, test := range []struct { name string @@ -1679,7 +1744,7 @@ func TestSynSent(t *testing.T) { addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} err := c.EP.Connect(addr) - if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" { + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) } @@ -1995,7 +2060,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { ) // Cause a FIN to be generated. - c.EP.Shutdown(tcpip.ShutdownWrite) + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } // Make sure we get the FIN but DON't ACK IT. checker.IPv4(t, c.GetPacket(), @@ -2011,7 +2078,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // Cause a RST to be generated by closing the read end now since we have // unread data. - c.EP.Shutdown(tcpip.ShutdownRead) + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } // Make sure we get the RST checker.IPv4(t, c.GetPacket(), |