From d29e59af9fbd420e34378bcbf7ae543134070217 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Mon, 27 Jan 2020 10:04:07 -0800 Subject: Standardize on tools directory. PiperOrigin-RevId: 291745021 --- pkg/tcpip/link/channel/BUILD | 3 +-- pkg/tcpip/link/fdbased/BUILD | 6 ++---- pkg/tcpip/link/loopback/BUILD | 3 +-- pkg/tcpip/link/muxed/BUILD | 6 ++---- pkg/tcpip/link/rawfile/BUILD | 3 +-- pkg/tcpip/link/sharedmem/BUILD | 6 ++---- pkg/tcpip/link/sharedmem/pipe/BUILD | 6 ++---- pkg/tcpip/link/sharedmem/queue/BUILD | 6 ++---- pkg/tcpip/link/sniffer/BUILD | 3 +-- pkg/tcpip/link/tun/BUILD | 3 +-- pkg/tcpip/link/waitable/BUILD | 6 ++---- 11 files changed, 17 insertions(+), 34 deletions(-) (limited to 'pkg/tcpip/link') diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 7dbc05754..3974c464e 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "channel", srcs = ["channel.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 66cc53ed4..abe725548 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -13,7 +12,6 @@ go_library( "mmap_unsafe.go", "packet_dispatchers.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -30,7 +28,7 @@ go_test( name = "fdbased_test", size = "small", srcs = ["endpoint_test.go"], - embed = [":fdbased"], + library = ":fdbased", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD index f35fcdff4..6bf3805b7 100644 --- a/pkg/tcpip/link/loopback/BUILD +++ b/pkg/tcpip/link/loopback/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "loopback", srcs = ["loopback.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 1ac7948b6..82b441b79 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "muxed", srcs = ["injectable.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -19,7 +17,7 @@ go_test( name = "muxed_test", size = "small", srcs = ["injectable_test.go"], - embed = [":muxed"], + library = ":muxed", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index d8211e93d..14b527bc2 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "errors.go", "rawfile_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index 09165dd4c..13243ebbb 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "sharedmem_unsafe.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -30,7 +28,7 @@ go_test( srcs = [ "sharedmem_test.go", ], - embed = [":sharedmem"], + library = ":sharedmem", deps = [ "//pkg/sync", "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index a0d4ad0be..87020ec08 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "rx.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe", visibility = ["//visibility:public"], ) @@ -20,6 +18,6 @@ go_test( srcs = [ "pipe_test.go", ], - embed = [":pipe"], + library = ":pipe", deps = ["//pkg/sync"], ) diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index 8c9234d54..3ba06af73 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "rx.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -22,7 +20,7 @@ go_test( srcs = [ "queue_test.go", ], - embed = [":queue"], + library = ":queue", deps = [ "//pkg/tcpip/link/sharedmem/pipe", ], diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index d6ae0368a..230a8d53a 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "pcap.go", "sniffer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer", visibility = ["//visibility:public"], deps = [ "//pkg/log", diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index a71a493fc..e5096ea38 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -1,10 +1,9 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "tun", srcs = ["tun_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun", visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 134837943..0956d2c65 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -8,7 +7,6 @@ go_library( srcs = [ "waitable.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable", visibility = ["//visibility:public"], deps = [ "//pkg/gate", @@ -23,7 +21,7 @@ go_test( srcs = [ "waitable_test.go", ], - embed = [":waitable"], + library = ":waitable", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", -- cgit v1.2.3 From 6b14be4246e8ed3779bf69dbd59e669caf3f5704 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Mon, 27 Jan 2020 10:08:18 -0800 Subject: Refactor to hide C from channel.Endpoint. This is to aid later implementation for /dev/net/tun device. PiperOrigin-RevId: 291746025 --- pkg/tcpip/link/channel/channel.go | 43 ++++- pkg/tcpip/network/arp/arp_test.go | 16 +- pkg/tcpip/network/ipv6/icmp_test.go | 7 +- pkg/tcpip/stack/ndp_test.go | 87 +++++----- pkg/tcpip/stack/stack_test.go | 4 +- pkg/tcpip/stack/transport_test.go | 6 +- pkg/tcpip/transport/tcp/testing/context/context.go | 86 +++++----- pkg/tcpip/transport/udp/udp_test.go | 182 +++++++++++---------- 8 files changed, 229 insertions(+), 202 deletions(-) (limited to 'pkg/tcpip/link') diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 70188551f..71b9da797 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -18,6 +18,8 @@ package channel import ( + "context" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -38,25 +40,52 @@ type Endpoint struct { linkAddr tcpip.LinkAddress GSO bool - // C is where outbound packets are queued. - C chan PacketInfo + // c is where outbound packets are queued. + c chan PacketInfo } // New creates a new channel endpoint. func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { return &Endpoint{ - C: make(chan PacketInfo, size), + c: make(chan PacketInfo, size), mtu: mtu, linkAddr: linkAddr, } } +// Close closes e. Further packet injections will panic. Reads continue to +// succeed until all packets are read. +func (e *Endpoint) Close() { + close(e.c) +} + +// Read does non-blocking read for one packet from the outbound packet queue. +func (e *Endpoint) Read() (PacketInfo, bool) { + select { + case pkt := <-e.c: + return pkt, true + default: + return PacketInfo{}, false + } +} + +// ReadContext does blocking read for one packet from the outbound packet queue. +// It can be cancelled by ctx, and in this case, it returns false. +func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) { + select { + case pkt := <-e.c: + return pkt, true + case <-ctx.Done(): + return PacketInfo{}, false + } +} + // Drain removes all outbound packets from the channel and counts them. func (e *Endpoint) Drain() int { c := 0 for { select { - case <-e.C: + case <-e.c: c++ default: return c @@ -125,7 +154,7 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.Ne } select { - case e.C <- p: + case e.c <- p: default: } @@ -150,7 +179,7 @@ packetLoop: } select { - case e.C <- p: + case e.c <- p: n++ default: break packetLoop @@ -169,7 +198,7 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { } select { - case e.C <- p: + case e.c <- p: default: } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 8e6048a21..03cf03b6d 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -15,6 +15,7 @@ package arp_test import ( + "context" "strconv" "testing" "time" @@ -83,7 +84,7 @@ func newTestContext(t *testing.T) *testContext { } func (c *testContext) cleanup() { - close(c.linkEP.C) + c.linkEP.Close() } func TestDirectRequest(t *testing.T) { @@ -110,7 +111,7 @@ func TestDirectRequest(t *testing.T) { for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) - pi := <-c.linkEP.C + pi, _ := c.linkEP.ReadContext(context.Background()) if pi.Proto != arp.ProtocolNumber { t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } @@ -134,12 +135,11 @@ func TestDirectRequest(t *testing.T) { } inject(stackAddrBad) - select { - case pkt := <-c.linkEP.C: + // Sleep tests are gross, but this will only potentially flake + // if there's a bug. If there is no bug this will reliably + // succeed. + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + if pkt, ok := c.linkEP.ReadContext(ctx); ok { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) - case <-time.After(100 * time.Millisecond): - // Sleep tests are gross, but this will only potentially flake - // if there's a bug. If there is no bug this will reliably - // succeed. } } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index a2fdc5dcd..7a6820643 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "context" "reflect" "strings" "testing" @@ -264,8 +265,8 @@ func newTestContext(t *testing.T) *testContext { } func (c *testContext) cleanup() { - close(c.linkEP0.C) - close(c.linkEP1.C) + c.linkEP0.Close() + c.linkEP1.Close() } type routeArgs struct { @@ -276,7 +277,7 @@ type routeArgs struct { func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pi := <-args.src.C + pi, _ := args.src.ReadContext(context.Background()) { views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index f9460bd51..ad2c6f601 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "context" "encoding/binary" "fmt" "testing" @@ -405,7 +406,7 @@ func TestDADResolve(t *testing.T) { // Validate the sent Neighbor Solicitation messages. for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p := <-e.C + p, _ := e.ReadContext(context.Background()) // Make sure its an IPv6 packet. if p.Proto != header.IPv6ProtocolNumber { @@ -3285,29 +3286,29 @@ func TestRouterSolicitation(t *testing.T) { e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1) waitForPkt := func(timeout time.Duration) { t.Helper() - select { - case p := <-e.C: - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, - p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(), - ) - - case <-time.After(timeout): + ctx, _ := context.WithTimeout(context.Background(), timeout) + p, ok := e.ReadContext(ctx) + if !ok { t.Fatal("timed out waiting for packet") + return } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, + p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(), + ) } waitForNothing := func(timeout time.Duration) { t.Helper() - select { - case <-e.C: + ctx, _ := context.WithTimeout(context.Background(), timeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet") - case <-time.After(timeout): } } s := stack.New(stack.Options{ @@ -3362,20 +3363,21 @@ func TestStopStartSolicitingRouters(t *testing.T) { e := channel.New(maxRtrSolicitations, 1280, linkAddr1) waitForPkt := func(timeout time.Duration) { t.Helper() - select { - case p := <-e.C: - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS()) - - case <-time.After(timeout): + ctx, _ := context.WithTimeout(context.Background(), timeout) + p, ok := e.ReadContext(ctx) + if !ok { t.Fatal("timed out waiting for packet") + return } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS()) } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, @@ -3391,23 +3393,20 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Enable forwarding which should stop router solicitations. s.SetForwarding(true) - select { - case <-e.C: + ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { // A single RS may have been sent before forwarding was enabled. - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) + if _, ok = e.ReadContext(ctx); ok { t.Fatal("Should not have sent more than one RS message") - case <-time.After(interval + defaultTimeout): } - case <-time.After(delay + defaultTimeout): } // Enabling forwarding again should do nothing. s.SetForwarding(true) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after becoming a router") - case <-time.After(delay + defaultTimeout): } // Disable forwarding which should start router solicitations. @@ -3415,17 +3414,15 @@ func TestStopStartSolicitingRouters(t *testing.T) { waitForPkt(delay + defaultAsyncEventTimeout) waitForPkt(interval + defaultAsyncEventTimeout) waitForPkt(interval + defaultAsyncEventTimeout) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") - case <-time.After(interval + defaultTimeout): } // Disabling forwarding again should do nothing. s.SetForwarding(false) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after becoming a router") - case <-time.After(delay + defaultTimeout): } } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index dad288642..834fe9487 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1880,9 +1880,7 @@ func TestNICForwarding(t *testing.T) { Data: buf.ToVectorisedView(), }) - select { - case <-ep2.C: - default: + if _, ok := ep2.Read(); !ok { t.Fatal("Packet not forwarded") } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index f50604a8a..869c69a6d 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -623,10 +623,8 @@ func TestTransportForwarding(t *testing.T) { t.Fatalf("Write failed: %v", err) } - var p channel.PacketInfo - select { - case p = <-ep2.C: - default: + p, ok := ep2.Read() + if !ok { t.Fatal("Response packet not forwarded") } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 822907998..730ac4292 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -18,6 +18,7 @@ package context import ( "bytes" + "context" "testing" "time" @@ -215,11 +216,9 @@ func (c *Context) Stack() *stack.Stack { func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { c.t.Helper() - select { - case <-c.linkEP.C: + ctx, _ := context.WithTimeout(context.Background(), wait) + if _, ok := c.linkEP.ReadContext(ctx); ok { c.t.Fatal(errMsg) - - case <-time.After(wait): } } @@ -234,27 +233,27 @@ func (c *Context) CheckNoPacket(errMsg string) { // 2 seconds. func (c *Context) GetPacket() []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + c.t.Fatalf("Packet wasn't written out") + return nil + } - if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) - } + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") + if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { + c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) } - return nil + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b } // GetPacketNonBlocking reads a packet from the link layer endpoint @@ -263,20 +262,21 @@ func (c *Context) GetPacket() []byte { // nil immediately. func (c *Context) GetPacketNonBlocking() []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b - default: + p, ok := c.linkEP.Read() + if !ok { return nil } + + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b } // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. @@ -484,23 +484,23 @@ func (c *Context) CreateV6Endpoint(v6only bool) { // and asserts that it is an IPv6 Packet with the expected src/dest addresses. func (c *Context) GetV6Packet() []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) - } - b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) - copy(b, p.Pkt.Header.View()) - copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) - - checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) - return b - case <-time.After(2 * time.Second): + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { c.t.Fatalf("Packet wasn't written out") + return nil + } + + if p.Proto != ipv6.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } + b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) + copy(b, p.Pkt.Header.View()) + copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) - return nil + checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) + return b } // SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index c6927cfe3..f0ff3fe71 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,6 +16,7 @@ package udp_test import ( "bytes" + "context" "fmt" "math/rand" "testing" @@ -357,30 +358,29 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != flow.netProto() { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) - } - - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - - h := flow.header4Tuple(outgoing) - checkers := append( - checkers, - checker.SrcAddr(h.srcAddr.Addr), - checker.DstAddr(h.dstAddr.Addr), - checker.UDP(checker.DstPort(h.dstAddr.Port)), - ) - flow.checkerFn()(c.t, b, checkers...) - return b - - case <-time.After(2 * time.Second): + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { c.t.Fatalf("Packet wasn't written out") + return nil } - return nil + if p.Proto != flow.netProto() { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) + } + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + + h := flow.header4Tuple(outgoing) + checkers = append( + checkers, + checker.SrcAddr(h.srcAddr.Addr), + checker.DstAddr(h.dstAddr.Addr), + checker.UDP(checker.DstPort(h.dstAddr.Port)), + ) + flow.checkerFn()(c.t, b, checkers...) + return b } // injectPacket creates a packet of the given flow and with the given payload, @@ -1541,48 +1541,50 @@ func TestV4UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - select { - case p := <-c.linkEP.C: + ctx, _ := context.WithTimeout(context.Background(), time.Second) + if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) - case <-time.After(1 * time.Second): - return } + return } - select { - case p := <-c.linkEP.C: - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) - if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } + // ICMP required. + ctx, _ := context.WithTimeout(context.Background(), time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + t.Fatalf("packet wasn't written out") + return + } - hdr := header.IPv4(pkt) - checker.IPv4(t, hdr, checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + var pkt []byte + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) + if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } - icmpPkt := header.ICMPv4(hdr.Payload()) - payloadIPHeader := header.IPv4(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize - } + hdr := header.IPv4(pkt) + checker.IPv4(t, hdr, checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + icmpPkt := header.ICMPv4(hdr.Payload()) + payloadIPHeader := header.IPv4(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize + } - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) - } - case <-time.After(1 * time.Second): - t.Fatalf("packet wasn't written out") + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %d, want: %d", got, want) } }) } @@ -1615,47 +1617,49 @@ func TestV6UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - select { - case p := <-c.linkEP.C: + ctx, _ := context.WithTimeout(context.Background(), time.Second) + if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) - case <-time.After(1 * time.Second): - return } + return } - select { - case p := <-c.linkEP.C: - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) - if got, want := len(pkt), header.IPv6MinimumMTU; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } + // ICMP required. + ctx, _ := context.WithTimeout(context.Background(), time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + t.Fatalf("packet wasn't written out") + return + } + + var pkt []byte + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) + if got, want := len(pkt), header.IPv6MinimumMTU; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } - hdr := header.IPv6(pkt) - checker.IPv6(t, hdr, checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) + hdr := header.IPv6(pkt) + checker.IPv6(t, hdr, checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - icmpPkt := header.ICMPv6(hdr.Payload()) - payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize - } - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) + icmpPkt := header.ICMPv6(hdr.Payload()) + payloadIPHeader := header.IPv6(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + } + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) - } - case <-time.After(1 * time.Second): - t.Fatalf("packet wasn't written out") + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %v, want: %v", got, want) } }) } -- cgit v1.2.3