summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/tests
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go35
1 files changed, 25 insertions, 10 deletions
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 0cb9d034e..38c2f321b 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -135,14 +135,15 @@ func TestForwarding(t *testing.T) {
name string
proto tcpip.TransportProtocolNumber
expectedConnectErr tcpip.Error
- setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
+ setupServer func(t *testing.T, ep tcpip.Endpoint)
+ setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
needRemoteAddr bool
}{
{
name: "UDP",
proto: udp.ProtocolNumber,
expectedConnectErr: nil,
- setupServerSide: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
t.Helper()
if err := ep.Connect(clientAddr); err != nil {
@@ -156,12 +157,16 @@ func TestForwarding(t *testing.T) {
name: "TCP",
proto: tcp.ProtocolNumber,
expectedConnectErr: &tcpip.ErrConnectStarted{},
- setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ setupServer: func(t *testing.T, ep tcpip.Endpoint) {
t.Helper()
if err := ep.Listen(1); err != nil {
t.Fatalf("ep.Listen(1): %s", err)
}
+ },
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
var addr tcpip.FullAddress
for {
newEP, wq, err := ep.Accept(&addr)
@@ -214,6 +219,9 @@ func TestForwarding(t *testing.T) {
t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
}
+ if subTest.setupServer != nil {
+ subTest.setupServer(t, epsAndAddrs.serverEP)
+ }
{
err := epsAndAddrs.clientEP.Connect(serverAddr)
if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
@@ -229,7 +237,7 @@ func TestForwarding(t *testing.T) {
serverEP := epsAndAddrs.serverEP
serverCH := epsAndAddrs.serverReadableCH
- if ep, ch := subTest.setupServerSide(t, serverEP, serverCH, clientAddr); ep != nil {
+ if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, clientAddr); ep != nil {
defer ep.Close()
serverEP = ep
serverCH = ch
@@ -256,13 +264,20 @@ func TestForwarding(t *testing.T) {
read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
t.Helper()
- // Wait for the endpoint to be readable.
- <-ch
var buf bytes.Buffer
- opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
- res, err := ep.Read(&buf, opts)
- if err != nil {
- t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ var res tcpip.ReadResult
+ for {
+ var err tcpip.Error
+ opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
+ res, err = ep.Read(&buf, opts)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ }
+ break
}
readResult := tcpip.ReadResult{