diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 78 | ||||
-rw-r--r-- | test/packetimpact/tests/tcp_zero_receive_window_test.go | 20 |
3 files changed, 98 insertions, 12 deletions
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index f2b1b68da..405a6dce7 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -172,14 +172,12 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // If we started off with a window larger than what can he held in // the 16bit window field, we ceil the value to the max value. - // While ceiling, we still do not want to grow the right edge when - // not applicable. if scaledWnd > math.MaxUint16 { - if toGrow { - scaledWnd = seqnum.Size(math.MaxUint16) - } else { - scaledWnd = seqnum.Size(uint16(scaledWnd)) - } + scaledWnd = seqnum.Size(math.MaxUint16) + + // Ensure that the stashed receive window always reflects what + // is being advertised. + r.rcvWnd = scaledWnd << r.rcvWndScale } return r.rcvNxt, scaledWnd } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 7581bdc97..351a5e4f5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1932,6 +1932,84 @@ func TestFullWindowReceive(t *testing.T) { ) } +// Test the stack receive window advertisement on receiving segments smaller than +// segment overhead. It tests for the right edge of the window to not grow when +// the endpoint is not being read from. +func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + opt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize, + Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)), + } + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } + + c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Bump up the receive buffer size such that, when the receive window grows, + // the scaled window exceeds maxUint16. + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil { + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err) + } + + // Keep the payload size < segment overhead and such that it is a multiple + // of the window scaled value. This enables the test to perform equality + // checks on the incoming receive window. + payload := generateRandomPayload(t, (tcp.SegSize-1)&(1<<c.RcvdWindowScale)) + payloadLen := seqnum.Size(len(payload)) + iss := seqnum.Value(789) + seqNum := iss.Add(1) + + // Send payload to the endpoint and return the advertised receive window + // from the endpoint. + getIncomingRcvWnd := func() uint32 { + c.SendPacket(payload, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: seqNum, + AckNum: c.IRS.Add(1), + Flags: header.TCPFlagAck, + RcvWnd: 30000, + }) + seqNum = seqNum.Add(payloadLen) + + pkt := c.GetPacket() + return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale + } + + // Read the advertised receive window with the ACK for payload. + rcvWnd := getIncomingRcvWnd() + + // Check if the subsequent ACK to our send has not grown the right edge of + // the window. + if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } + + // Read the data so that the subsequent ACK from the endpoint + // grows the right edge of the window. + if _, _, err := c.EP.Read(nil); err != nil { + t.Fatalf("got Read(nil) = %s", err) + } + + // Check if we have received max uint16 as our advertised + // scaled window now after a read above. + maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale) + if got, want := getIncomingRcvWnd(), maxRcv; got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } + + // Check if the subsequent ACK to our send has not grown the right edge of + // the window. + if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } +} + func TestNoWindowShrinking(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/test/packetimpact/tests/tcp_zero_receive_window_test.go b/test/packetimpact/tests/tcp_zero_receive_window_test.go index b494e179b..d06690705 100644 --- a/test/packetimpact/tests/tcp_zero_receive_window_test.go +++ b/test/packetimpact/tests/tcp_zero_receive_window_test.go @@ -15,7 +15,6 @@ package tcp_zero_receive_window_test import ( - "context" "flag" "fmt" "testing" @@ -49,12 +48,24 @@ func TestZeroReceiveWindow(t *testing.T) { samplePayload := &testbench.Payload{Bytes: testbench.GenerateRandomPayload(t, payloadLen)} // Expect the DUT to eventually advertise zero receive window. // The test would timeout otherwise. - for { + for readOnce := false; ; { conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected packet was not received: %s", err) } + // Read once to trigger the subsequent window update from the + // DUT to grow the right edge of the receive window from what + // was advertised in the SYN-ACK. This ensures that we test + // for the full default buffer size (1MB on gVisor at the time + // of writing this comment), thus testing for cases when the + // scaled receive window size ends up > 65535 (0xffff). + if !readOnce { + if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen { + t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen) + } + readOnce = true + } windowSize := *gotTCP.WindowSize t.Logf("got window size = %d", windowSize) if windowSize == 0 { @@ -94,10 +105,9 @@ func TestNonZeroReceiveWindow(t *testing.T) { if err != nil { t.Fatalf("expected packet was not received: %s", err) } - if ret, _, err := dut.RecvWithErrno(context.Background(), t, acceptFd, int32(payloadLen), 0); ret == -1 { - t.Fatalf("dut.RecvWithErrno(ctx, t, %d, %d, 0) = %d,_, %s", acceptFd, payloadLen, ret, err) + if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen { + t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen) } - if *gotTCP.WindowSize == 0 { t.Fatalf("expected non-zero receive window.") } |