diff options
-rw-r--r-- | pkg/tcpip/tests/integration/forward_test.go | 35 |
1 files changed, 25 insertions, 10 deletions
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 0cb9d034e..38c2f321b 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -135,14 +135,15 @@ func TestForwarding(t *testing.T) { name string proto tcpip.TransportProtocolNumber expectedConnectErr tcpip.Error - setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) + setupServer func(t *testing.T, ep tcpip.Endpoint) + setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) needRemoteAddr bool }{ { name: "UDP", proto: udp.ProtocolNumber, expectedConnectErr: nil, - setupServerSide: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { t.Helper() if err := ep.Connect(clientAddr); err != nil { @@ -156,12 +157,16 @@ func TestForwarding(t *testing.T) { name: "TCP", proto: tcp.ProtocolNumber, expectedConnectErr: &tcpip.ErrConnectStarted{}, - setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + setupServer: func(t *testing.T, ep tcpip.Endpoint) { t.Helper() if err := ep.Listen(1); err != nil { t.Fatalf("ep.Listen(1): %s", err) } + }, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + var addr tcpip.FullAddress for { newEP, wq, err := ep.Accept(&addr) @@ -214,6 +219,9 @@ func TestForwarding(t *testing.T) { t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) } + if subTest.setupServer != nil { + subTest.setupServer(t, epsAndAddrs.serverEP) + } { err := epsAndAddrs.clientEP.Connect(serverAddr) if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { @@ -229,7 +237,7 @@ func TestForwarding(t *testing.T) { serverEP := epsAndAddrs.serverEP serverCH := epsAndAddrs.serverReadableCH - if ep, ch := subTest.setupServerSide(t, serverEP, serverCH, clientAddr); ep != nil { + if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, clientAddr); ep != nil { defer ep.Close() serverEP = ep serverCH = ch @@ -256,13 +264,20 @@ func TestForwarding(t *testing.T) { read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { t.Helper() - // Wait for the endpoint to be readable. - <-ch var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} - res, err := ep.Read(&buf, opts) - if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + var res tcpip.ReadResult + for { + var err tcpip.Error + opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} + res, err = ep.Read(&buf, opts) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + } + break } readResult := tcpip.ReadResult{ |