diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 8 | ||||
-rw-r--r-- | test/packetimpact/testbench/dut.go | 4 | ||||
-rw-r--r-- | test/packetimpact/tests/BUILD | 1 | ||||
-rw-r--r-- | test/packetimpact/tests/tcp_syncookie_test.go | 18 |
5 files changed, 31 insertions, 7 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 2b5abd3ee..d807b13b7 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -740,6 +740,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err mss: rcvdSynOptions.MSS, }) + // Requeue the segment if the ACK completing the handshake has more info + // to be procesed by the newly established endpoint. + if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && n.enqueueSegment(s) { + s.incRef() + n.newSegmentWaker.Assert() + } + // Do the delivery in a separate goroutine so // that we don't block the listen loop in case // the application is slow to accept or stops diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 9958547d3..2137ebc25 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -406,11 +406,11 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { h.ep.transitionToStateEstablishedLocked(h) - // If the segment has data then requeue it for the receiver - // to process it again once main loop is started. - if s.data.Size() > 0 { + // Requeue the segment if the ACK completing the handshake has more info + // to be procesed by the newly established endpoint. + if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && h.ep.enqueueSegment(s) { s.incRef() - h.ep.enqueueSegment(s) + h.ep.newSegmentWaker.Assert() } return nil } diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index 269e163bb..0cac0bf1b 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -498,8 +498,8 @@ func (dut *DUT) PollOne(t *testing.T, fd int32, events int16, timeout time.Durat if readyFd := pfds[0].Fd; readyFd != fd { t.Fatalf("Poll returned an fd %d that was not requested (%d)", readyFd, fd) } - if got, want := pfds[0].Revents, int16(events); got&want == 0 { - t.Fatalf("Poll returned no events in our interest, got: %#b, want: %#b", got, want) + if got, want := pfds[0].Revents, int16(events); got&want != want { + t.Fatalf("Poll returned events does not include all of the interested events, got: %#b, want: %#b", got, want) } } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 37b59b1d9..4cff0cf4c 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -384,6 +384,7 @@ packetimpact_testbench( deps = [ "//pkg/tcpip/header", "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/test/packetimpact/tests/tcp_syncookie_test.go b/test/packetimpact/tests/tcp_syncookie_test.go index 1a016bd1a..6be09996b 100644 --- a/test/packetimpact/tests/tcp_syncookie_test.go +++ b/test/packetimpact/tests/tcp_syncookie_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/test/packetimpact/testbench" @@ -114,7 +115,10 @@ func TestTCPSynCookie(t *testing.T) { t.Fatalf("dut.Poll(...) = %d, want = %d", got, want) } - c.conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(test.flags)}) + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + c.conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(test.flags)}, samplePayload) pfds = dut.Poll(t, []unix.PollFd{{Fd: listenFD, Events: unix.POLLIN}}, time.Second) want := 0 if test.accept { @@ -126,6 +130,18 @@ func TestTCPSynCookie(t *testing.T) { // Accept the connection to enable poll on any subsequent connection. if test.accept { fd, _ := dut.Accept(t, listenFD) + if test.flags.Contains(header.TCPFlagFin) { + if dut.Uname.IsLinux() { + dut.PollOne(t, fd, unix.POLLIN|unix.POLLRDHUP, time.Second) + } else { + // TODO(gvisor.dev/issue/6015): Notify POLLIN|POLLRDHUP on incoming FIN. + dut.PollOne(t, fd, unix.POLLIN, time.Second) + } + } + got := dut.Recv(t, fd, int32(len(sampleData)), 0) + if diff := cmp.Diff(got, sampleData); diff != "" { + t.Fatalf("dut.Recv: data mismatch (-want +got):\n%s", diff) + } dut.Close(t, fd) } }) |