diff options
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r-- | pkg/tcpip/network/arp/arp_test.go | 101 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 391 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 979 |
3 files changed, 467 insertions, 1004 deletions
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 24357e15d..c8b9ff9fc 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -17,7 +17,6 @@ package arp_test import ( "context" "fmt" - "strconv" "testing" "time" @@ -155,7 +154,7 @@ type testContext struct { nudDisp *arpDispatcher } -func newTestContext(t *testing.T, useNeighborCache bool) *testContext { +func newTestContext(t *testing.T) *testContext { c := stack.DefaultNUDConfigurations() // Transition from Reachable to Stale almost immediately to test if receiving // probes refreshes positive reachability. @@ -173,7 +172,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, NUDConfigs: c, NUDDisp: &d, - UseNeighborCache: useNeighborCache, }) ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) @@ -191,15 +189,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext { if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress for ipv4 failed: %v", err) } - if !useNeighborCache { - // The remote address needs to be assigned to the NIC so we can receive and - // verify outgoing ARP packets. The neighbor cache isn't concerned with - // this; the tests that use linkAddrCache expect the ARP responses to be - // received by the same NIC. - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, remoteAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) - } - } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, @@ -217,86 +206,8 @@ func (c *testContext) cleanup() { c.linkEP.Close() } -func TestDirectRequest(t *testing.T) { - c := newTestContext(t, false /* useNeighborCache */) - defer c.cleanup() - - const senderMAC = "\x01\x02\x03\x04\x05\x06" - const senderIPv4 = "\x0a\x00\x00\x02" - - v := make(buffer.View, header.ARPSize) - h := header.ARP(v) - h.SetIPv4OverEthernet() - h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), senderMAC) - copy(h.ProtocolAddressSender(), senderIPv4) - - inject := func(addr tcpip.Address) { - copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - })) - } - - for i, address := range []tcpip.Address{stackAddr, remoteAddr} { - t.Run(strconv.Itoa(i), func(t *testing.T) { - expectedPacketsReceived := c.s.Stats().ARP.PacketsReceived.Value() + 1 - expectedRequestsReceived := c.s.Stats().ARP.RequestsReceived.Value() + 1 - expectedRepliesSent := c.s.Stats().ARP.OutgoingRepliesSent.Value() + 1 - - inject(address) - pi, _ := c.linkEP.ReadContext(context.Background()) - if pi.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) - } - rep := header.ARP(pi.Pkt.NetworkHeader().View()) - if !rep.IsValid() { - t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) - } - if got := rep.Op(); got != header.ARPReply { - t.Fatalf("got Op = %d, want = %d", got, header.ARPReply) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { - t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want) - } - - if got := c.s.Stats().ARP.PacketsReceived.Value(); got != expectedPacketsReceived { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, expectedPacketsReceived) - } - if got := c.s.Stats().ARP.RequestsReceived.Value(); got != expectedRequestsReceived { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, expectedRequestsReceived) - } - if got := c.s.Stats().ARP.OutgoingRepliesSent.Value(); got != expectedRepliesSent { - t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, expectedRepliesSent) - } - }) - } - - inject(unknownAddr) - // Sleep tests are gross, but this will only potentially flake - // if there's a bug. If there is no bug this will reliably - // succeed. - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - if pkt, ok := c.linkEP.ReadContext(ctx); ok { - t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) - } - if got := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnKnownTargetAddress.Value() = %d, want = 1", got) - } -} - func TestMalformedPacket(t *testing.T) { - c := newTestContext(t, false) + c := newTestContext(t) defer c.cleanup() v := make(buffer.View, header.ARPSize) @@ -315,7 +226,7 @@ func TestMalformedPacket(t *testing.T) { } func TestDisabledEndpoint(t *testing.T) { - c := newTestContext(t, false) + c := newTestContext(t) defer c.cleanup() ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber) @@ -340,7 +251,7 @@ func TestDisabledEndpoint(t *testing.T) { } func TestDirectReply(t *testing.T) { - c := newTestContext(t, false) + c := newTestContext(t) defer c.cleanup() const senderMAC = "\x01\x02\x03\x04\x05\x06" @@ -370,8 +281,8 @@ func TestDirectReply(t *testing.T) { } } -func TestDirectRequestWithNeighborCache(t *testing.T) { - c := newTestContext(t, true /* useNeighborCache */) +func TestDirectRequest(t *testing.T) { + c := newTestContext(t) defer c.cleanup() tests := []struct { diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 4dbfb80da..69c1e4bea 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -175,168 +175,9 @@ func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp hea } func TestICMPCounts(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - UseNeighborCache: test.useNeighborCache, - }) - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - typ header.ICMPv6Type - size int - extraData []byte - }{ - { - typ: header.ICMPv6DstUnreachable, - size: header.ICMPv6DstUnreachableMinimumSize, - }, - { - typ: header.ICMPv6PacketTooBig, - size: header.ICMPv6PacketTooBigMinimumSize, - }, - { - typ: header.ICMPv6TimeExceeded, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6ParamProblem, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6EchoRequest, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6EchoReply, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - }, - { - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - }, - { - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - }, - { - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6MulticastListenerQuery, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: header.ICMPv6MulticastListenerReport, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: header.ICMPv6MulticastListenerDone, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: 255, /* Unrecognized */ - size: 50, - }, - } - - for _, typ := range types { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - handleICMPInIPv6(ep, lladdr1, lladdr0, icmp) - } - - // Construct an empty ICMP packet so that - // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) - - icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived - visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { - if got, want := s.Value(), uint64(1); got != want { - t.Errorf("got %s = %d, want = %d", name, got, want) - } - }) - if t.Failed() { - t.Logf("stats:\n%+v", s.Stats()) - } - }) - } -} - -func TestICMPCountsWithNeighborCache(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) @@ -770,135 +611,116 @@ func TestICMPChecksumValidationSimple(t *testing.T) { }, } - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } + for _, typ := range types { + for _, isRouter := range []bool{false, true} { + name := typ.name + if isRouter { + name += " (Router)" + } + t.Run(name, func(t *testing.T) { + e := channel.New(0, 1280, linkAddr0) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, typ := range types { - for _, isRouter := range []bool{false, true} { - name := typ.name - if isRouter { - name += " (Router)" + // Indicate that resolution for link layer addresses is required to + // send packets over this link. This is needed so the NIC knows to + // allocate a neighbor table. + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + if isRouter { + // Enabling forwarding makes the stack act as a router. + s.SetForwarding(ProtocolNumber, true) + } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) + } + + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}, + ) + } + + handleIPv6Payload := func(checksum bool) { + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + if checksum { + icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) } - t.Run(name, func(t *testing.T) { - e := channel.New(0, 1280, linkAddr0) - - // Indicate that resolution for link layer addresses is required to - // send packets over this link. This is needed so the NIC knows to - // allocate a neighbor table. - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: test.useNeighborCache, - }) - if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) - } - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - handleIPv6Payload := func(checksum bool) { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - if checksum { - icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) - } - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), - }) - e.InjectInbound(ProtocolNumber, pkt) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - routerOnly := stats.RouterOnlyPacketsDroppedByHost - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := routerOnly.Value(); got != 0 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Router only count should not have increased. - if got := routerOnly.Value(); got != 0 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - if !isRouter && typ.routerOnly && test.useNeighborCache { - // Router only count should have increased. - if got := routerOnly.Value(); got != 1 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got) - } - } + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), + }) + e.InjectInbound(ProtocolNumber, pkt) } - } - }) + + stats := s.Stats().ICMP.V6.PacketsReceived + invalid := stats.Invalid + routerOnly := stats.RouterOnlyPacketsDroppedByHost + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := routerOnly.Value(); got != 0 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Router only count should not have increased. + if got := routerOnly.Value(); got != 0 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + if !isRouter && typ.routerOnly { + // Router only count should have increased. + if got := routerOnly.Value(); got != 1 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got) + } + } + }) + } } } @@ -1762,7 +1584,6 @@ func TestCallsToNeighborCache(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - UseNeighborCache: true, }) { if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 104fe2139..5ff247653 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -33,13 +33,12 @@ import ( // setupStackAndEndpoint creates a stack with a single NIC with a link-local // address llladdr and an IPv6 endpoint to a remote with link-local address // rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeighborCache bool) (*stack.Stack, stack.NetworkEndpoint) { +func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { @@ -237,107 +236,6 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { Data: hdr.View().ToVectorisedView(), })) - ch := make(chan stack.LinkResolutionResult, 1) - err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { - ch <- r - }) - - wantInvalid := uint64(0) - wantSucccess := true - if len(test.expectedLinkAddr) == 0 { - wantInvalid = 1 - wantSucccess = false - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{}) - } - } else { - if err != nil { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - } - - if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { - t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) - } - if got := invalid.Value(); got != wantInvalid { - t.Errorf("got invalid = %d, want = %d", got, wantInvalid) - } - }) - } -} - -// TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests -// that receiving a valid NDP NS message with the Source Link Layer Address -// option results in a new entry in the link address cache for the sender of -// the message. -func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "Valid", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, - expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", - }, - { - name: "Too Small", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: true, - }) - e := channel.New(0, 1280, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(lladdr0) - opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - neighbors, err := s.Neighbors(nicID, ProtocolNumber) if err != nil { t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) @@ -392,20 +290,6 @@ func TestNeighborSolicitationResponse(t *testing.T) { remoteLinkAddr0 := linkAddr1 remoteLinkAddr1 := linkAddr2 - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - tests := []struct { name string nsOpts header.NDPOptionsSerializer @@ -564,229 +448,44 @@ func TestNeighborSolicitationResponse(t *testing.T) { }, } - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: stackTyp.useNeighborCache, - }) - e := channel.New(1, 1280, nicLinkAddr) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(nicAddr) - opts := ns.Options() - opts.Serialize(test.nsOpts) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - if test.nsInvalid { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - - if p, got := e.Read(); got { - t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) - } - - // If we expected the NS to be invalid, we have nothing else to check. - return - } - - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - if test.performsLinkResolution { - p, got := e.ReadContext(context.Background()) - if !got { - t.Fatal("expected an NDP NS response") - } - - respNSDst := header.SolicitedNodeAddr(test.nsSrc) - var want stack.RouteInfo - want.NetProto = ProtocolNumber - want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst) - if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" { - t.Errorf("route info mismatch (-want +got):\n%s", diff) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(nicAddr), - checker.DstAddr(respNSDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(test.nsSrc), - checker.NDPNSOptions([]header.NDPOption{ - header.NDPSourceLinkLayerAddressOption(nicLinkAddr), - }), - )) - - ser := header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - } - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(test.nsSrc) - na.Options().Serialize(ser) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: test.nsSrc, - DstAddr: nicAddr, - }) - e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - p, got := e.ReadContext(context.Background()) - if !got { - t.Fatal("expected an NDP NA response") - } - - if p.Route.LocalAddress != test.naSrc { - t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc) - } - if p.Route.LocalLinkAddress != nicLinkAddr { - t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) - } - if p.Route.RemoteAddress != test.naDst { - t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) - } - if p.Route.RemoteLinkAddress != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.naSrc), - checker.DstAddr(test.naDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNA( - checker.NDPNASolicitedFlag(test.naSolicited), - checker.NDPNATargetAddress(nicAddr), - checker.NDPNAOptions([]header.NDPOption{ - header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), - }), - )) - }) - } - }) - } -} - -// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a -// valid NDP NA message with the Target Link Layer Address option results in a -// new entry in the link address cache for the target of the message. -func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "Valid", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, - expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", - }, - { - name: "Too Small", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, - }, - { - name: "Multiple", - optsBuf: []byte{ - 2, 1, 2, 3, 4, 5, 6, 7, - 2, 1, 2, 3, 4, 5, 6, 8, - }, - }, - } - for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseLinkAddrCache: true, }) - e := channel.New(0, 1280, linkAddr0) + e := channel.New(1, 1280, nicLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) } - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.MessageBody()) - ns.SetTargetAddress(lladdr1) + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) + ns.SetTargetAddress(nicAddr) opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + opts.Serialize(test.nsOpts) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), TransportProtocol: header.ICMPv6ProtocolNumber, HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, }) invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid @@ -796,44 +495,116 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) - ch := make(chan stack.LinkResolutionResult, 1) - err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { - ch <- r - }) + if test.nsInvalid { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } - wantInvalid := uint64(0) - wantSucccess := true - if len(test.expectedLinkAddr) == 0 { - wantInvalid = 1 - wantSucccess = false - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{}) + if p, got := e.Read(); got { + t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) } - } else { - if err != nil { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err) + + // If we expected the NS to be invalid, we have nothing else to check. + return + } + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + if test.performsLinkResolution { + p, got := e.ReadContext(context.Background()) + if !got { + t.Fatal("expected an NDP NS response") + } + + respNSDst := header.SolicitedNodeAddr(test.nsSrc) + var want stack.RouteInfo + want.NetProto = ProtocolNumber + want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst) + if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" { + t.Errorf("route info mismatch (-want +got):\n%s", diff) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(nicAddr), + checker.DstAddr(respNSDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(test.nsSrc), + checker.NDPNSOptions([]header.NDPOption{ + header.NDPSourceLinkLayerAddressOption(nicLinkAddr), + }), + )) + + ser := header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), } + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.MessageBody()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(test.nsSrc) + na.Options().Serialize(ser) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: test.nsSrc, + DstAddr: nicAddr, + }) + e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) } - if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { - t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) + p, got := e.ReadContext(context.Background()) + if !got { + t.Fatal("expected an NDP NA response") } - if got := invalid.Value(); got != wantInvalid { - t.Errorf("got invalid = %d, want = %d", got, wantInvalid) + + if p.Route.LocalAddress != test.naSrc { + t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc) } + if p.Route.LocalLinkAddress != nicLinkAddr { + t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) + } + if p.Route.RemoteAddress != test.naDst { + t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) + } + if p.Route.RemoteLinkAddress != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.naSrc), + checker.DstAddr(test.naDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNA( + checker.NDPNASolicitedFlag(test.naSolicited), + checker.NDPNATargetAddress(nicAddr), + checker.NDPNAOptions([]header.NDPOption{ + header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), + }), + )) }) } } -// TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests -// that receiving a valid NDP NA message with the Target Link Layer Address -// option does not result in a new entry in the neighbor cache for the target -// of the message. -func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) { +// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a +// valid NDP NA message with the Target Link Layer Address option does not +// result in a new entry in the neighbor cache for the target of the message. +func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { const nicID = 1 tests := []struct { @@ -867,7 +638,6 @@ func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *tes t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: true, }) e := channel.New(0, 1280, linkAddr0) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -944,235 +714,216 @@ func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *tes } func TestNDPValidation(t *testing.T) { - stacks := []struct { - name string - useNeighborCache bool + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { + t.Helper() + + // Create a stack with the assigned link-local address lladdr0 + // and an endpoint to lladdr1. + s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) + + return s, ep + } + + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { + var extHdrs header.IPv6ExtHdrSerializer + if atomicFragment { + extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) + } + extHdrsLen := extHdrs.Length() + + ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen) + header.IPv6(ip).Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(payload) + extHdrsLen), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: hopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + ExtensionHeaders: extHdrs, + }) + vv := ip.ToVectorisedView() + vv.AppendView(payload) + ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + } + + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + + var sllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + + types := []struct { + name string + typ header.ICMPv6Type + size int + extraData []byte + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + routerOnly bool }{ { - name: "linkAddrCache", - useNeighborCache: false, + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + routerOnly: true, }, { - name: "neighborCache", - useNeighborCache: true, + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + extraData: sllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, }, } - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache) + subTests := []struct { + name string + atomicFragment bool + hopLimit uint8 + code header.ICMPv6Code + valid bool + }{ + { + name: "Valid", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 0, + valid: true, + }, + { + name: "Fragmented", + atomicFragment: true, + hopLimit: header.NDPHopLimit, + code: 0, + valid: false, + }, + { + name: "Invalid hop limit", + atomicFragment: false, + hopLimit: header.NDPHopLimit - 1, + code: 0, + valid: false, + }, + { + name: "Invalid ICMPv6 code", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 1, + valid: false, + }, + } - return s, ep + for _, typ := range types { + for _, isRouter := range []bool{false, true} { + name := typ.name + if isRouter { + name += " (Router)" } - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { - var extHdrs header.IPv6ExtHdrSerializer - if atomicFragment { - extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) - } - extHdrsLen := extHdrs.Length() + t.Run(name, func(t *testing.T) { + for _, test := range subTests { + t.Run(test.name, func(t *testing.T) { + s, ep := setup(t) - ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen) - header.IPv6(ip).Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload) + extHdrsLen), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: hopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - ExtensionHeaders: extHdrs, - }) - vv := ip.ToVectorisedView() - vv.AppendView(payload) - ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } + if isRouter { + // Enabling forwarding makes the stack act as a router. + s.SetForwarding(ProtocolNumber, true) + } - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) + stats := s.Stats().ICMP.V6.PacketsReceived + invalid := stats.Invalid + routerOnly := stats.RouterOnlyPacketsDroppedByHost + typStat := typ.statCounter(stats) - var sllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetCode(test.code) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - routerOnly bool - }{ - { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - routerOnly: true, - }, - { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, - }, - { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - extraData: sllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, - }, - { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, - }, - { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }, - }, - } + // Rx count of the NDP message should initially be 0. + if got := typStat.Value(); got != 0 { + t.Errorf("got %s = %d, want = 0", typ.name, got) + } - subTests := []struct { - name string - atomicFragment bool - hopLimit uint8 - code header.ICMPv6Code - valid bool - }{ - { - name: "Valid", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 0, - valid: true, - }, - { - name: "Fragmented", - atomicFragment: true, - hopLimit: header.NDPHopLimit, - code: 0, - valid: false, - }, - { - name: "Invalid hop limit", - atomicFragment: false, - hopLimit: header.NDPHopLimit - 1, - code: 0, - valid: false, - }, - { - name: "Invalid ICMPv6 code", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 1, - valid: false, - }, - } + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } - for _, typ := range types { - for _, isRouter := range []bool{false, true} { - name := typ.name - if isRouter { - name += " (Router)" - } + // RouterOnlyPacketsReceivedByHost count should initially be 0. + if got := routerOnly.Value(); got != 0 { + t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + + if t.Failed() { + t.FailNow() + } + + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) + + // Rx count of the NDP packet should have increased. + if got := typStat.Value(); got != 1 { + t.Errorf("got %s = %d, want = 1", typ.name, got) + } - t.Run(name, func(t *testing.T) { - for _, test := range subTests { - t.Run(test.name, func(t *testing.T) { - s, ep := setup(t) - - if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - routerOnly := stats.RouterOnlyPacketsDroppedByHost - typStat := typ.statCounter(stats) - - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - - // Rx count of the NDP message should initially be 0. - if got := typStat.Value(); got != 0 { - t.Errorf("got %s = %d, want = 0", typ.name, got) - } - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - - // RouterOnlyPacketsReceivedByHost count should initially be 0. - if got := routerOnly.Value(); got != 0 { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - - if t.Failed() { - t.FailNow() - } - - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) - - // Rx count of the NDP packet should have increased. - if got := typStat.Value(); got != 1 { - t.Errorf("got %s = %d, want = 1", typ.name, got) - } - - want := uint64(0) - if !test.valid { - // Invalid count should have increased. - want = 1 - } - if got := invalid.Value(); got != want { - t.Errorf("got invalid = %d, want = %d", got, want) - } - - want = 0 - if test.valid && !isRouter && typ.routerOnly { - // RouterOnlyPacketsReceivedByHost count should have increased. - want = 1 - } - if got := routerOnly.Value(); got != want { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) - } - - }) + want := uint64(0) + if !test.valid { + // Invalid count should have increased. + want = 1 } + if got := invalid.Value(); got != want { + t.Errorf("got invalid = %d, want = %d", got, want) + } + + want = 0 + if test.valid && !isRouter && typ.routerOnly { + // RouterOnlyPacketsReceivedByHost count should have increased. + want = 1 + } + if got := routerOnly.Value(); got != want { + t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) + } + }) } - } - }) + }) + } } - } // TestNeighborAdvertisementValidation tests that the NIC validates received @@ -1218,7 +969,6 @@ func TestNeighborAdvertisementValidation(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: true, }) e := channel.New(0, header.IPv6MinimumMTU, linkAddr0) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -1291,20 +1041,6 @@ func TestNeighborAdvertisementValidation(t *testing.T) { // NDP Router Advertisement packets, it validates the Router Advertisement // properly before handling them. func TestRouterAdvertValidation(t *testing.T) { - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - tests := []struct { name string src tcpip.Address @@ -1426,67 +1162,62 @@ func TestRouterAdvertValidation(t *testing.T) { }, } - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: stackTyp.useNeighborCache, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } - icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(test.code) - copy(pkt.MessageBody(), test.ndpPayload) - payloadLength := hdr.UsedLength() - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.MessageBody(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - rxRA := stats.RouterAdvert + stats := s.Stats().ICMP.V6.PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) - if got := rxRA.Value(); got != 1 { - t.Fatalf("got rxRA = %d, want = 1", got) - } + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } - if test.expectedSuccess { - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - } else { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - } - }) + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } } }) } |