From d29e59af9fbd420e34378bcbf7ae543134070217 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Mon, 27 Jan 2020 10:04:07 -0800 Subject: Standardize on tools directory. PiperOrigin-RevId: 291745021 --- pkg/tcpip/network/BUILD | 2 +- pkg/tcpip/network/arp/BUILD | 4 +--- pkg/tcpip/network/fragmentation/BUILD | 6 ++---- pkg/tcpip/network/hash/BUILD | 3 +-- pkg/tcpip/network/ipv4/BUILD | 4 +--- pkg/tcpip/network/ipv6/BUILD | 6 ++---- 6 files changed, 8 insertions(+), 17 deletions(-) (limited to 'pkg/tcpip/network') diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 9d16ff8c9..6a4839fb8 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index e7617229b..eddf7b725 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "arp", srcs = ["arp.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index ed16076fd..d1c728ccf 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -24,7 +23,6 @@ go_library( "reassembler.go", "reassembler_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -42,6 +40,6 @@ go_test( "fragmentation_test.go", "reassembler_test.go", ], - embed = [":fragmentation"], + library = ":fragmentation", deps = ["//pkg/tcpip/buffer"], ) diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD index e6db5c0b0..872165866 100644 --- a/pkg/tcpip/network/hash/BUILD +++ b/pkg/tcpip/network/hash/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "hash", srcs = ["hash.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/hash", visibility = ["//visibility:public"], deps = [ "//pkg/rand", diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 4e2aae9a3..0fef2b1f1 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "icmp.go", "ipv4.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index e4e273460..fb11874c6 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "icmp.go", "ipv6.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -27,7 +25,7 @@ go_test( "ipv6_test.go", "ndp_test.go", ], - embed = [":ipv6"], + library = ":ipv6", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", -- cgit v1.2.3 From 6b14be4246e8ed3779bf69dbd59e669caf3f5704 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Mon, 27 Jan 2020 10:08:18 -0800 Subject: Refactor to hide C from channel.Endpoint. This is to aid later implementation for /dev/net/tun device. PiperOrigin-RevId: 291746025 --- pkg/tcpip/link/channel/channel.go | 43 ++++- pkg/tcpip/network/arp/arp_test.go | 16 +- pkg/tcpip/network/ipv6/icmp_test.go | 7 +- pkg/tcpip/stack/ndp_test.go | 87 +++++----- pkg/tcpip/stack/stack_test.go | 4 +- pkg/tcpip/stack/transport_test.go | 6 +- pkg/tcpip/transport/tcp/testing/context/context.go | 86 +++++----- pkg/tcpip/transport/udp/udp_test.go | 182 +++++++++++---------- 8 files changed, 229 insertions(+), 202 deletions(-) (limited to 'pkg/tcpip/network') diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 70188551f..71b9da797 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -18,6 +18,8 @@ package channel import ( + "context" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -38,25 +40,52 @@ type Endpoint struct { linkAddr tcpip.LinkAddress GSO bool - // C is where outbound packets are queued. - C chan PacketInfo + // c is where outbound packets are queued. + c chan PacketInfo } // New creates a new channel endpoint. func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { return &Endpoint{ - C: make(chan PacketInfo, size), + c: make(chan PacketInfo, size), mtu: mtu, linkAddr: linkAddr, } } +// Close closes e. Further packet injections will panic. Reads continue to +// succeed until all packets are read. +func (e *Endpoint) Close() { + close(e.c) +} + +// Read does non-blocking read for one packet from the outbound packet queue. +func (e *Endpoint) Read() (PacketInfo, bool) { + select { + case pkt := <-e.c: + return pkt, true + default: + return PacketInfo{}, false + } +} + +// ReadContext does blocking read for one packet from the outbound packet queue. +// It can be cancelled by ctx, and in this case, it returns false. +func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) { + select { + case pkt := <-e.c: + return pkt, true + case <-ctx.Done(): + return PacketInfo{}, false + } +} + // Drain removes all outbound packets from the channel and counts them. func (e *Endpoint) Drain() int { c := 0 for { select { - case <-e.C: + case <-e.c: c++ default: return c @@ -125,7 +154,7 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.Ne } select { - case e.C <- p: + case e.c <- p: default: } @@ -150,7 +179,7 @@ packetLoop: } select { - case e.C <- p: + case e.c <- p: n++ default: break packetLoop @@ -169,7 +198,7 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { } select { - case e.C <- p: + case e.c <- p: default: } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 8e6048a21..03cf03b6d 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -15,6 +15,7 @@ package arp_test import ( + "context" "strconv" "testing" "time" @@ -83,7 +84,7 @@ func newTestContext(t *testing.T) *testContext { } func (c *testContext) cleanup() { - close(c.linkEP.C) + c.linkEP.Close() } func TestDirectRequest(t *testing.T) { @@ -110,7 +111,7 @@ func TestDirectRequest(t *testing.T) { for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) - pi := <-c.linkEP.C + pi, _ := c.linkEP.ReadContext(context.Background()) if pi.Proto != arp.ProtocolNumber { t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } @@ -134,12 +135,11 @@ func TestDirectRequest(t *testing.T) { } inject(stackAddrBad) - select { - case pkt := <-c.linkEP.C: + // Sleep tests are gross, but this will only potentially flake + // if there's a bug. If there is no bug this will reliably + // succeed. + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + if pkt, ok := c.linkEP.ReadContext(ctx); ok { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) - case <-time.After(100 * time.Millisecond): - // Sleep tests are gross, but this will only potentially flake - // if there's a bug. If there is no bug this will reliably - // succeed. } } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index a2fdc5dcd..7a6820643 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "context" "reflect" "strings" "testing" @@ -264,8 +265,8 @@ func newTestContext(t *testing.T) *testContext { } func (c *testContext) cleanup() { - close(c.linkEP0.C) - close(c.linkEP1.C) + c.linkEP0.Close() + c.linkEP1.Close() } type routeArgs struct { @@ -276,7 +277,7 @@ type routeArgs struct { func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pi := <-args.src.C + pi, _ := args.src.ReadContext(context.Background()) { views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index f9460bd51..ad2c6f601 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "context" "encoding/binary" "fmt" "testing" @@ -405,7 +406,7 @@ func TestDADResolve(t *testing.T) { // Validate the sent Neighbor Solicitation messages. for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p := <-e.C + p, _ := e.ReadContext(context.Background()) // Make sure its an IPv6 packet. if p.Proto != header.IPv6ProtocolNumber { @@ -3285,29 +3286,29 @@ func TestRouterSolicitation(t *testing.T) { e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1) waitForPkt := func(timeout time.Duration) { t.Helper() - select { - case p := <-e.C: - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, - p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(), - ) - - case <-time.After(timeout): + ctx, _ := context.WithTimeout(context.Background(), timeout) + p, ok := e.ReadContext(ctx) + if !ok { t.Fatal("timed out waiting for packet") + return } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, + p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(), + ) } waitForNothing := func(timeout time.Duration) { t.Helper() - select { - case <-e.C: + ctx, _ := context.WithTimeout(context.Background(), timeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet") - case <-time.After(timeout): } } s := stack.New(stack.Options{ @@ -3362,20 +3363,21 @@ func TestStopStartSolicitingRouters(t *testing.T) { e := channel.New(maxRtrSolicitations, 1280, linkAddr1) waitForPkt := func(timeout time.Duration) { t.Helper() - select { - case p := <-e.C: - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS()) - - case <-time.After(timeout): + ctx, _ := context.WithTimeout(context.Background(), timeout) + p, ok := e.ReadContext(ctx) + if !ok { t.Fatal("timed out waiting for packet") + return } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS()) } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, @@ -3391,23 +3393,20 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Enable forwarding which should stop router solicitations. s.SetForwarding(true) - select { - case <-e.C: + ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { // A single RS may have been sent before forwarding was enabled. - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) + if _, ok = e.ReadContext(ctx); ok { t.Fatal("Should not have sent more than one RS message") - case <-time.After(interval + defaultTimeout): } - case <-time.After(delay + defaultTimeout): } // Enabling forwarding again should do nothing. s.SetForwarding(true) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after becoming a router") - case <-time.After(delay + defaultTimeout): } // Disable forwarding which should start router solicitations. @@ -3415,17 +3414,15 @@ func TestStopStartSolicitingRouters(t *testing.T) { waitForPkt(delay + defaultAsyncEventTimeout) waitForPkt(interval + defaultAsyncEventTimeout) waitForPkt(interval + defaultAsyncEventTimeout) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") - case <-time.After(interval + defaultTimeout): } // Disabling forwarding again should do nothing. s.SetForwarding(false) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after becoming a router") - case <-time.After(delay + defaultTimeout): } } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index dad288642..834fe9487 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1880,9 +1880,7 @@ func TestNICForwarding(t *testing.T) { Data: buf.ToVectorisedView(), }) - select { - case <-ep2.C: - default: + if _, ok := ep2.Read(); !ok { t.Fatal("Packet not forwarded") } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index f50604a8a..869c69a6d 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -623,10 +623,8 @@ func TestTransportForwarding(t *testing.T) { t.Fatalf("Write failed: %v", err) } - var p channel.PacketInfo - select { - case p = <-ep2.C: - default: + p, ok := ep2.Read() + if !ok { t.Fatal("Response packet not forwarded") } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 822907998..730ac4292 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -18,6 +18,7 @@ package context import ( "bytes" + "context" "testing" "time" @@ -215,11 +216,9 @@ func (c *Context) Stack() *stack.Stack { func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { c.t.Helper() - select { - case <-c.linkEP.C: + ctx, _ := context.WithTimeout(context.Background(), wait) + if _, ok := c.linkEP.ReadContext(ctx); ok { c.t.Fatal(errMsg) - - case <-time.After(wait): } } @@ -234,27 +233,27 @@ func (c *Context) CheckNoPacket(errMsg string) { // 2 seconds. func (c *Context) GetPacket() []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + c.t.Fatalf("Packet wasn't written out") + return nil + } - if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) - } + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") + if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { + c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) } - return nil + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b } // GetPacketNonBlocking reads a packet from the link layer endpoint @@ -263,20 +262,21 @@ func (c *Context) GetPacket() []byte { // nil immediately. func (c *Context) GetPacketNonBlocking() []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b - default: + p, ok := c.linkEP.Read() + if !ok { return nil } + + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b } // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. @@ -484,23 +484,23 @@ func (c *Context) CreateV6Endpoint(v6only bool) { // and asserts that it is an IPv6 Packet with the expected src/dest addresses. func (c *Context) GetV6Packet() []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) - } - b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) - copy(b, p.Pkt.Header.View()) - copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) - - checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) - return b - case <-time.After(2 * time.Second): + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { c.t.Fatalf("Packet wasn't written out") + return nil + } + + if p.Proto != ipv6.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } + b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) + copy(b, p.Pkt.Header.View()) + copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) - return nil + checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) + return b } // SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index c6927cfe3..f0ff3fe71 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,6 +16,7 @@ package udp_test import ( "bytes" + "context" "fmt" "math/rand" "testing" @@ -357,30 +358,29 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != flow.netProto() { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) - } - - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - - h := flow.header4Tuple(outgoing) - checkers := append( - checkers, - checker.SrcAddr(h.srcAddr.Addr), - checker.DstAddr(h.dstAddr.Addr), - checker.UDP(checker.DstPort(h.dstAddr.Port)), - ) - flow.checkerFn()(c.t, b, checkers...) - return b - - case <-time.After(2 * time.Second): + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { c.t.Fatalf("Packet wasn't written out") + return nil } - return nil + if p.Proto != flow.netProto() { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) + } + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + + h := flow.header4Tuple(outgoing) + checkers = append( + checkers, + checker.SrcAddr(h.srcAddr.Addr), + checker.DstAddr(h.dstAddr.Addr), + checker.UDP(checker.DstPort(h.dstAddr.Port)), + ) + flow.checkerFn()(c.t, b, checkers...) + return b } // injectPacket creates a packet of the given flow and with the given payload, @@ -1541,48 +1541,50 @@ func TestV4UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - select { - case p := <-c.linkEP.C: + ctx, _ := context.WithTimeout(context.Background(), time.Second) + if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) - case <-time.After(1 * time.Second): - return } + return } - select { - case p := <-c.linkEP.C: - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) - if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } + // ICMP required. + ctx, _ := context.WithTimeout(context.Background(), time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + t.Fatalf("packet wasn't written out") + return + } - hdr := header.IPv4(pkt) - checker.IPv4(t, hdr, checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + var pkt []byte + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) + if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } - icmpPkt := header.ICMPv4(hdr.Payload()) - payloadIPHeader := header.IPv4(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize - } + hdr := header.IPv4(pkt) + checker.IPv4(t, hdr, checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + icmpPkt := header.ICMPv4(hdr.Payload()) + payloadIPHeader := header.IPv4(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize + } - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) - } - case <-time.After(1 * time.Second): - t.Fatalf("packet wasn't written out") + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %d, want: %d", got, want) } }) } @@ -1615,47 +1617,49 @@ func TestV6UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - select { - case p := <-c.linkEP.C: + ctx, _ := context.WithTimeout(context.Background(), time.Second) + if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) - case <-time.After(1 * time.Second): - return } + return } - select { - case p := <-c.linkEP.C: - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) - if got, want := len(pkt), header.IPv6MinimumMTU; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } + // ICMP required. + ctx, _ := context.WithTimeout(context.Background(), time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + t.Fatalf("packet wasn't written out") + return + } + + var pkt []byte + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) + if got, want := len(pkt), header.IPv6MinimumMTU; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } - hdr := header.IPv6(pkt) - checker.IPv6(t, hdr, checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) + hdr := header.IPv6(pkt) + checker.IPv6(t, hdr, checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - icmpPkt := header.ICMPv6(hdr.Payload()) - payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize - } - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) + icmpPkt := header.ICMPv6(hdr.Payload()) + payloadIPHeader := header.IPv6(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + } + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) - } - case <-time.After(1 * time.Second): - t.Fatalf("packet wasn't written out") + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %v, want: %v", got, want) } }) } -- cgit v1.2.3 From 431ff52768c2300e15cba609c2be4f507fd30d5b Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 28 Jan 2020 15:39:48 -0800 Subject: Update link address for senders of Neighbor Solicitations Update link address for senders of NDP Neighbor Solicitations when the NS contains an NDP Source Link Layer Address option. Tests: - ipv6.TestNeighorSolicitationWithSourceLinkLayerOption - ipv6.TestNeighorSolicitationWithInvalidSourceLinkLayerOption PiperOrigin-RevId: 292028553 --- pkg/tcpip/network/ipv6/icmp.go | 41 ++++++----- pkg/tcpip/network/ipv6/ndp_test.go | 135 +++++++++++++++++++++++++++++++++++++ 2 files changed, 161 insertions(+), 15 deletions(-) (limited to 'pkg/tcpip/network') diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 1c3410618..dc20c0fd7 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -137,21 +137,24 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P } ns := header.NDPNeighborSolicit(h.NDPPayload()) + it, err := ns.Options().Iter(true) + if err != nil { + // If we have a malformed NDP NS option, drop the packet. + received.Invalid.Increment() + return + } + targetAddr := ns.TargetAddress() s := r.Stack() rxNICID := r.NICID() - - isTentative, err := s.IsAddrTentative(rxNICID, targetAddr) - if err != nil { + if isTentative, err := s.IsAddrTentative(rxNICID, targetAddr); err != nil { // We will only get an error if rxNICID is unrecognized, // which should not happen. For now short-circuit this // packet. // // TODO(b/141002840): Handle this better? return - } - - if isTentative { + } else if isTentative { // If the target address is tentative and the source // of the packet is a unicast (specified) address, then // the source of the packet is attempting to perform @@ -185,6 +188,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P return } + // If the NS message has the source link layer option, update the link + // address cache with the link address for the sender of the message. + // + // TODO(b/148429853): Properly process the NS message and do Neighbor + // Unreachability Detection. + for { + opt, done, _ := it.Next() + if done { + break + } + + switch opt := opt.(type) { + case header.NDPSourceLinkLayerAddressOption: + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, opt.EthernetAddress()) + } + } + optsSerializer := header.NDPOptionsSerializer{ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]), } @@ -211,15 +231,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P r.LocalAddress = targetAddr packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - // TODO(tamird/ghanan): there exists an explicit NDP option that is - // used to update the neighbor table with link addresses for a - // neighbor from an NS (see the Source Link Layer option RFC - // 4861 section 4.6.1 and section 7.2.3). - // - // Furthermore, the entirety of NDP handling here seems to be - // contradicted by RFC 4861. - e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) - // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) // // 7.1.2. Validation of Neighbor Advertisements diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index fe895b376..bd732f93f 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -70,6 +70,141 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack return s, ep } +// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving an +// NDP NS message with the Source Link Layer Address option results in a +// new entry in the link address cache for the sender of the message. +func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr0) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr0) + ns.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) + if err != nil { + t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) + } + if c != nil { + t.Errorf("got unexpected channel") + } + if linkAddr != linkAddr1 { + t.Errorf("got link address = %s, want = %s", linkAddr, linkAddr1) + } +} + +// TestNeighorSolicitationWithInvalidSourceLinkLayerOption tests that receiving +// an NDP NS message with an invalid Source Link Layer Address option does not +// result in a new entry in the link address cache for the sender of the +// message. +func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + optsBuf []byte + }{ + { + name: "Too Small", + optsBuf: []byte{1, 1, 1, 2, 3, 4, 5}, + }, + { + name: "Invalid Length", + optsBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr0) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr0) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + + linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) + } + if c == nil { + t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + } + if linkAddr != "" { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (%s, _, ), want = ('', _, _)", nicID, lladdr1, lladdr0, ProtocolNumber, linkAddr) + } + }) + } +} + // TestHopLimitValidation is a test that makes sure that NDP packets are only // received if their IP header's hop limit is set to 255. func TestHopLimitValidation(t *testing.T) { -- cgit v1.2.3