diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 87 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 6 |
3 files changed, 45 insertions, 52 deletions
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index f9460bd51..ad2c6f601 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "context" "encoding/binary" "fmt" "testing" @@ -405,7 +406,7 @@ func TestDADResolve(t *testing.T) { // Validate the sent Neighbor Solicitation messages. for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p := <-e.C + p, _ := e.ReadContext(context.Background()) // Make sure its an IPv6 packet. if p.Proto != header.IPv6ProtocolNumber { @@ -3285,29 +3286,29 @@ func TestRouterSolicitation(t *testing.T) { e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1) waitForPkt := func(timeout time.Duration) { t.Helper() - select { - case p := <-e.C: - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, - p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(), - ) - - case <-time.After(timeout): + ctx, _ := context.WithTimeout(context.Background(), timeout) + p, ok := e.ReadContext(ctx) + if !ok { t.Fatal("timed out waiting for packet") + return } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, + p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(), + ) } waitForNothing := func(timeout time.Duration) { t.Helper() - select { - case <-e.C: + ctx, _ := context.WithTimeout(context.Background(), timeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet") - case <-time.After(timeout): } } s := stack.New(stack.Options{ @@ -3362,20 +3363,21 @@ func TestStopStartSolicitingRouters(t *testing.T) { e := channel.New(maxRtrSolicitations, 1280, linkAddr1) waitForPkt := func(timeout time.Duration) { t.Helper() - select { - case p := <-e.C: - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS()) - - case <-time.After(timeout): + ctx, _ := context.WithTimeout(context.Background(), timeout) + p, ok := e.ReadContext(ctx) + if !ok { t.Fatal("timed out waiting for packet") + return } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS()) } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, @@ -3391,23 +3393,20 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Enable forwarding which should stop router solicitations. s.SetForwarding(true) - select { - case <-e.C: + ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { // A single RS may have been sent before forwarding was enabled. - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) + if _, ok = e.ReadContext(ctx); ok { t.Fatal("Should not have sent more than one RS message") - case <-time.After(interval + defaultTimeout): } - case <-time.After(delay + defaultTimeout): } // Enabling forwarding again should do nothing. s.SetForwarding(true) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after becoming a router") - case <-time.After(delay + defaultTimeout): } // Disable forwarding which should start router solicitations. @@ -3415,17 +3414,15 @@ func TestStopStartSolicitingRouters(t *testing.T) { waitForPkt(delay + defaultAsyncEventTimeout) waitForPkt(interval + defaultAsyncEventTimeout) waitForPkt(interval + defaultAsyncEventTimeout) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") - case <-time.After(interval + defaultTimeout): } // Disabling forwarding again should do nothing. s.SetForwarding(false) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after becoming a router") - case <-time.After(delay + defaultTimeout): } } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index dad288642..834fe9487 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1880,9 +1880,7 @@ func TestNICForwarding(t *testing.T) { Data: buf.ToVectorisedView(), }) - select { - case <-ep2.C: - default: + if _, ok := ep2.Read(); !ok { t.Fatal("Packet not forwarded") } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index f50604a8a..869c69a6d 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -623,10 +623,8 @@ func TestTransportForwarding(t *testing.T) { t.Fatalf("Write failed: %v", err) } - var p channel.PacketInfo - select { - case p = <-ep2.C: - default: + p, ok := ep2.Read() + if !ok { t.Fatal("Response packet not forwarded") } |