diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/network/arp/arp_test.go | 101 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 391 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 979 | ||||
-rw-r--r-- | pkg/tcpip/stack/BUILD | 15 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 640 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 359 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 291 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 773 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_test.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 64 | ||||
-rw-r--r-- | pkg/tcpip/stack/nud_test.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 145 |
17 files changed, 1199 insertions, 2657 deletions
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 24357e15d..c8b9ff9fc 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -17,7 +17,6 @@ package arp_test import ( "context" "fmt" - "strconv" "testing" "time" @@ -155,7 +154,7 @@ type testContext struct { nudDisp *arpDispatcher } -func newTestContext(t *testing.T, useNeighborCache bool) *testContext { +func newTestContext(t *testing.T) *testContext { c := stack.DefaultNUDConfigurations() // Transition from Reachable to Stale almost immediately to test if receiving // probes refreshes positive reachability. @@ -173,7 +172,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, NUDConfigs: c, NUDDisp: &d, - UseNeighborCache: useNeighborCache, }) ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) @@ -191,15 +189,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext { if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress for ipv4 failed: %v", err) } - if !useNeighborCache { - // The remote address needs to be assigned to the NIC so we can receive and - // verify outgoing ARP packets. The neighbor cache isn't concerned with - // this; the tests that use linkAddrCache expect the ARP responses to be - // received by the same NIC. - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, remoteAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) - } - } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, @@ -217,86 +206,8 @@ func (c *testContext) cleanup() { c.linkEP.Close() } -func TestDirectRequest(t *testing.T) { - c := newTestContext(t, false /* useNeighborCache */) - defer c.cleanup() - - const senderMAC = "\x01\x02\x03\x04\x05\x06" - const senderIPv4 = "\x0a\x00\x00\x02" - - v := make(buffer.View, header.ARPSize) - h := header.ARP(v) - h.SetIPv4OverEthernet() - h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), senderMAC) - copy(h.ProtocolAddressSender(), senderIPv4) - - inject := func(addr tcpip.Address) { - copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - })) - } - - for i, address := range []tcpip.Address{stackAddr, remoteAddr} { - t.Run(strconv.Itoa(i), func(t *testing.T) { - expectedPacketsReceived := c.s.Stats().ARP.PacketsReceived.Value() + 1 - expectedRequestsReceived := c.s.Stats().ARP.RequestsReceived.Value() + 1 - expectedRepliesSent := c.s.Stats().ARP.OutgoingRepliesSent.Value() + 1 - - inject(address) - pi, _ := c.linkEP.ReadContext(context.Background()) - if pi.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) - } - rep := header.ARP(pi.Pkt.NetworkHeader().View()) - if !rep.IsValid() { - t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) - } - if got := rep.Op(); got != header.ARPReply { - t.Fatalf("got Op = %d, want = %d", got, header.ARPReply) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { - t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want) - } - - if got := c.s.Stats().ARP.PacketsReceived.Value(); got != expectedPacketsReceived { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, expectedPacketsReceived) - } - if got := c.s.Stats().ARP.RequestsReceived.Value(); got != expectedRequestsReceived { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, expectedRequestsReceived) - } - if got := c.s.Stats().ARP.OutgoingRepliesSent.Value(); got != expectedRepliesSent { - t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, expectedRepliesSent) - } - }) - } - - inject(unknownAddr) - // Sleep tests are gross, but this will only potentially flake - // if there's a bug. If there is no bug this will reliably - // succeed. - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - if pkt, ok := c.linkEP.ReadContext(ctx); ok { - t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) - } - if got := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnKnownTargetAddress.Value() = %d, want = 1", got) - } -} - func TestMalformedPacket(t *testing.T) { - c := newTestContext(t, false) + c := newTestContext(t) defer c.cleanup() v := make(buffer.View, header.ARPSize) @@ -315,7 +226,7 @@ func TestMalformedPacket(t *testing.T) { } func TestDisabledEndpoint(t *testing.T) { - c := newTestContext(t, false) + c := newTestContext(t) defer c.cleanup() ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber) @@ -340,7 +251,7 @@ func TestDisabledEndpoint(t *testing.T) { } func TestDirectReply(t *testing.T) { - c := newTestContext(t, false) + c := newTestContext(t) defer c.cleanup() const senderMAC = "\x01\x02\x03\x04\x05\x06" @@ -370,8 +281,8 @@ func TestDirectReply(t *testing.T) { } } -func TestDirectRequestWithNeighborCache(t *testing.T) { - c := newTestContext(t, true /* useNeighborCache */) +func TestDirectRequest(t *testing.T) { + c := newTestContext(t) defer c.cleanup() tests := []struct { diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 4dbfb80da..69c1e4bea 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -175,168 +175,9 @@ func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp hea } func TestICMPCounts(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - UseNeighborCache: test.useNeighborCache, - }) - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - typ header.ICMPv6Type - size int - extraData []byte - }{ - { - typ: header.ICMPv6DstUnreachable, - size: header.ICMPv6DstUnreachableMinimumSize, - }, - { - typ: header.ICMPv6PacketTooBig, - size: header.ICMPv6PacketTooBigMinimumSize, - }, - { - typ: header.ICMPv6TimeExceeded, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6ParamProblem, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6EchoRequest, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6EchoReply, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - }, - { - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - }, - { - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - }, - { - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6MulticastListenerQuery, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: header.ICMPv6MulticastListenerReport, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: header.ICMPv6MulticastListenerDone, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: 255, /* Unrecognized */ - size: 50, - }, - } - - for _, typ := range types { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - handleICMPInIPv6(ep, lladdr1, lladdr0, icmp) - } - - // Construct an empty ICMP packet so that - // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) - - icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived - visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { - if got, want := s.Value(), uint64(1); got != want { - t.Errorf("got %s = %d, want = %d", name, got, want) - } - }) - if t.Failed() { - t.Logf("stats:\n%+v", s.Stats()) - } - }) - } -} - -func TestICMPCountsWithNeighborCache(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) @@ -770,135 +611,116 @@ func TestICMPChecksumValidationSimple(t *testing.T) { }, } - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } + for _, typ := range types { + for _, isRouter := range []bool{false, true} { + name := typ.name + if isRouter { + name += " (Router)" + } + t.Run(name, func(t *testing.T) { + e := channel.New(0, 1280, linkAddr0) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, typ := range types { - for _, isRouter := range []bool{false, true} { - name := typ.name - if isRouter { - name += " (Router)" + // Indicate that resolution for link layer addresses is required to + // send packets over this link. This is needed so the NIC knows to + // allocate a neighbor table. + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + if isRouter { + // Enabling forwarding makes the stack act as a router. + s.SetForwarding(ProtocolNumber, true) + } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) + } + + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}, + ) + } + + handleIPv6Payload := func(checksum bool) { + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + if checksum { + icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) } - t.Run(name, func(t *testing.T) { - e := channel.New(0, 1280, linkAddr0) - - // Indicate that resolution for link layer addresses is required to - // send packets over this link. This is needed so the NIC knows to - // allocate a neighbor table. - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: test.useNeighborCache, - }) - if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) - } - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - handleIPv6Payload := func(checksum bool) { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - if checksum { - icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) - } - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), - }) - e.InjectInbound(ProtocolNumber, pkt) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - routerOnly := stats.RouterOnlyPacketsDroppedByHost - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := routerOnly.Value(); got != 0 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Router only count should not have increased. - if got := routerOnly.Value(); got != 0 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - if !isRouter && typ.routerOnly && test.useNeighborCache { - // Router only count should have increased. - if got := routerOnly.Value(); got != 1 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got) - } - } + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), + }) + e.InjectInbound(ProtocolNumber, pkt) } - } - }) + + stats := s.Stats().ICMP.V6.PacketsReceived + invalid := stats.Invalid + routerOnly := stats.RouterOnlyPacketsDroppedByHost + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := routerOnly.Value(); got != 0 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Router only count should not have increased. + if got := routerOnly.Value(); got != 0 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + if !isRouter && typ.routerOnly { + // Router only count should have increased. + if got := routerOnly.Value(); got != 1 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got) + } + } + }) + } } } @@ -1762,7 +1584,6 @@ func TestCallsToNeighborCache(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - UseNeighborCache: true, }) { if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 104fe2139..5ff247653 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -33,13 +33,12 @@ import ( // setupStackAndEndpoint creates a stack with a single NIC with a link-local // address llladdr and an IPv6 endpoint to a remote with link-local address // rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeighborCache bool) (*stack.Stack, stack.NetworkEndpoint) { +func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { @@ -237,107 +236,6 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { Data: hdr.View().ToVectorisedView(), })) - ch := make(chan stack.LinkResolutionResult, 1) - err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { - ch <- r - }) - - wantInvalid := uint64(0) - wantSucccess := true - if len(test.expectedLinkAddr) == 0 { - wantInvalid = 1 - wantSucccess = false - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{}) - } - } else { - if err != nil { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - } - - if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { - t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) - } - if got := invalid.Value(); got != wantInvalid { - t.Errorf("got invalid = %d, want = %d", got, wantInvalid) - } - }) - } -} - -// TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests -// that receiving a valid NDP NS message with the Source Link Layer Address -// option results in a new entry in the link address cache for the sender of -// the message. -func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "Valid", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, - expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", - }, - { - name: "Too Small", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: true, - }) - e := channel.New(0, 1280, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(lladdr0) - opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - neighbors, err := s.Neighbors(nicID, ProtocolNumber) if err != nil { t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) @@ -392,20 +290,6 @@ func TestNeighborSolicitationResponse(t *testing.T) { remoteLinkAddr0 := linkAddr1 remoteLinkAddr1 := linkAddr2 - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - tests := []struct { name string nsOpts header.NDPOptionsSerializer @@ -564,229 +448,44 @@ func TestNeighborSolicitationResponse(t *testing.T) { }, } - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: stackTyp.useNeighborCache, - }) - e := channel.New(1, 1280, nicLinkAddr) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(nicAddr) - opts := ns.Options() - opts.Serialize(test.nsOpts) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - if test.nsInvalid { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - - if p, got := e.Read(); got { - t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) - } - - // If we expected the NS to be invalid, we have nothing else to check. - return - } - - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - if test.performsLinkResolution { - p, got := e.ReadContext(context.Background()) - if !got { - t.Fatal("expected an NDP NS response") - } - - respNSDst := header.SolicitedNodeAddr(test.nsSrc) - var want stack.RouteInfo - want.NetProto = ProtocolNumber - want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst) - if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" { - t.Errorf("route info mismatch (-want +got):\n%s", diff) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(nicAddr), - checker.DstAddr(respNSDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(test.nsSrc), - checker.NDPNSOptions([]header.NDPOption{ - header.NDPSourceLinkLayerAddressOption(nicLinkAddr), - }), - )) - - ser := header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - } - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(test.nsSrc) - na.Options().Serialize(ser) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: test.nsSrc, - DstAddr: nicAddr, - }) - e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - p, got := e.ReadContext(context.Background()) - if !got { - t.Fatal("expected an NDP NA response") - } - - if p.Route.LocalAddress != test.naSrc { - t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc) - } - if p.Route.LocalLinkAddress != nicLinkAddr { - t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) - } - if p.Route.RemoteAddress != test.naDst { - t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) - } - if p.Route.RemoteLinkAddress != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.naSrc), - checker.DstAddr(test.naDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNA( - checker.NDPNASolicitedFlag(test.naSolicited), - checker.NDPNATargetAddress(nicAddr), - checker.NDPNAOptions([]header.NDPOption{ - header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), - }), - )) - }) - } - }) - } -} - -// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a -// valid NDP NA message with the Target Link Layer Address option results in a -// new entry in the link address cache for the target of the message. -func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "Valid", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, - expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", - }, - { - name: "Too Small", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, - }, - { - name: "Multiple", - optsBuf: []byte{ - 2, 1, 2, 3, 4, 5, 6, 7, - 2, 1, 2, 3, 4, 5, 6, 8, - }, - }, - } - for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseLinkAddrCache: true, }) - e := channel.New(0, 1280, linkAddr0) + e := channel.New(1, 1280, nicLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) } - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.MessageBody()) - ns.SetTargetAddress(lladdr1) + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) + ns.SetTargetAddress(nicAddr) opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + opts.Serialize(test.nsOpts) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), TransportProtocol: header.ICMPv6ProtocolNumber, HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, }) invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid @@ -796,44 +495,116 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) - ch := make(chan stack.LinkResolutionResult, 1) - err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { - ch <- r - }) + if test.nsInvalid { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } - wantInvalid := uint64(0) - wantSucccess := true - if len(test.expectedLinkAddr) == 0 { - wantInvalid = 1 - wantSucccess = false - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{}) + if p, got := e.Read(); got { + t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) } - } else { - if err != nil { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err) + + // If we expected the NS to be invalid, we have nothing else to check. + return + } + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + if test.performsLinkResolution { + p, got := e.ReadContext(context.Background()) + if !got { + t.Fatal("expected an NDP NS response") + } + + respNSDst := header.SolicitedNodeAddr(test.nsSrc) + var want stack.RouteInfo + want.NetProto = ProtocolNumber + want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst) + if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" { + t.Errorf("route info mismatch (-want +got):\n%s", diff) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(nicAddr), + checker.DstAddr(respNSDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(test.nsSrc), + checker.NDPNSOptions([]header.NDPOption{ + header.NDPSourceLinkLayerAddressOption(nicLinkAddr), + }), + )) + + ser := header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), } + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.MessageBody()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(test.nsSrc) + na.Options().Serialize(ser) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: test.nsSrc, + DstAddr: nicAddr, + }) + e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) } - if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { - t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) + p, got := e.ReadContext(context.Background()) + if !got { + t.Fatal("expected an NDP NA response") } - if got := invalid.Value(); got != wantInvalid { - t.Errorf("got invalid = %d, want = %d", got, wantInvalid) + + if p.Route.LocalAddress != test.naSrc { + t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc) } + if p.Route.LocalLinkAddress != nicLinkAddr { + t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) + } + if p.Route.RemoteAddress != test.naDst { + t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) + } + if p.Route.RemoteLinkAddress != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.naSrc), + checker.DstAddr(test.naDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNA( + checker.NDPNASolicitedFlag(test.naSolicited), + checker.NDPNATargetAddress(nicAddr), + checker.NDPNAOptions([]header.NDPOption{ + header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), + }), + )) }) } } -// TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests -// that receiving a valid NDP NA message with the Target Link Layer Address -// option does not result in a new entry in the neighbor cache for the target -// of the message. -func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) { +// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a +// valid NDP NA message with the Target Link Layer Address option does not +// result in a new entry in the neighbor cache for the target of the message. +func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { const nicID = 1 tests := []struct { @@ -867,7 +638,6 @@ func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *tes t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: true, }) e := channel.New(0, 1280, linkAddr0) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -944,235 +714,216 @@ func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *tes } func TestNDPValidation(t *testing.T) { - stacks := []struct { - name string - useNeighborCache bool + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { + t.Helper() + + // Create a stack with the assigned link-local address lladdr0 + // and an endpoint to lladdr1. + s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) + + return s, ep + } + + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { + var extHdrs header.IPv6ExtHdrSerializer + if atomicFragment { + extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) + } + extHdrsLen := extHdrs.Length() + + ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen) + header.IPv6(ip).Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(payload) + extHdrsLen), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: hopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + ExtensionHeaders: extHdrs, + }) + vv := ip.ToVectorisedView() + vv.AppendView(payload) + ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + } + + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + + var sllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + + types := []struct { + name string + typ header.ICMPv6Type + size int + extraData []byte + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + routerOnly bool }{ { - name: "linkAddrCache", - useNeighborCache: false, + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + routerOnly: true, }, { - name: "neighborCache", - useNeighborCache: true, + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + extraData: sllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, }, } - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache) + subTests := []struct { + name string + atomicFragment bool + hopLimit uint8 + code header.ICMPv6Code + valid bool + }{ + { + name: "Valid", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 0, + valid: true, + }, + { + name: "Fragmented", + atomicFragment: true, + hopLimit: header.NDPHopLimit, + code: 0, + valid: false, + }, + { + name: "Invalid hop limit", + atomicFragment: false, + hopLimit: header.NDPHopLimit - 1, + code: 0, + valid: false, + }, + { + name: "Invalid ICMPv6 code", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 1, + valid: false, + }, + } - return s, ep + for _, typ := range types { + for _, isRouter := range []bool{false, true} { + name := typ.name + if isRouter { + name += " (Router)" } - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { - var extHdrs header.IPv6ExtHdrSerializer - if atomicFragment { - extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) - } - extHdrsLen := extHdrs.Length() + t.Run(name, func(t *testing.T) { + for _, test := range subTests { + t.Run(test.name, func(t *testing.T) { + s, ep := setup(t) - ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen) - header.IPv6(ip).Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload) + extHdrsLen), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: hopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - ExtensionHeaders: extHdrs, - }) - vv := ip.ToVectorisedView() - vv.AppendView(payload) - ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - } + if isRouter { + // Enabling forwarding makes the stack act as a router. + s.SetForwarding(ProtocolNumber, true) + } - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) + stats := s.Stats().ICMP.V6.PacketsReceived + invalid := stats.Invalid + routerOnly := stats.RouterOnlyPacketsDroppedByHost + typStat := typ.statCounter(stats) - var sllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetCode(test.code) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - routerOnly bool - }{ - { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - routerOnly: true, - }, - { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, - }, - { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - extraData: sllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, - }, - { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, - }, - { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }, - }, - } + // Rx count of the NDP message should initially be 0. + if got := typStat.Value(); got != 0 { + t.Errorf("got %s = %d, want = 0", typ.name, got) + } - subTests := []struct { - name string - atomicFragment bool - hopLimit uint8 - code header.ICMPv6Code - valid bool - }{ - { - name: "Valid", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 0, - valid: true, - }, - { - name: "Fragmented", - atomicFragment: true, - hopLimit: header.NDPHopLimit, - code: 0, - valid: false, - }, - { - name: "Invalid hop limit", - atomicFragment: false, - hopLimit: header.NDPHopLimit - 1, - code: 0, - valid: false, - }, - { - name: "Invalid ICMPv6 code", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 1, - valid: false, - }, - } + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } - for _, typ := range types { - for _, isRouter := range []bool{false, true} { - name := typ.name - if isRouter { - name += " (Router)" - } + // RouterOnlyPacketsReceivedByHost count should initially be 0. + if got := routerOnly.Value(); got != 0 { + t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + + if t.Failed() { + t.FailNow() + } + + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) + + // Rx count of the NDP packet should have increased. + if got := typStat.Value(); got != 1 { + t.Errorf("got %s = %d, want = 1", typ.name, got) + } - t.Run(name, func(t *testing.T) { - for _, test := range subTests { - t.Run(test.name, func(t *testing.T) { - s, ep := setup(t) - - if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - routerOnly := stats.RouterOnlyPacketsDroppedByHost - typStat := typ.statCounter(stats) - - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - - // Rx count of the NDP message should initially be 0. - if got := typStat.Value(); got != 0 { - t.Errorf("got %s = %d, want = 0", typ.name, got) - } - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - - // RouterOnlyPacketsReceivedByHost count should initially be 0. - if got := routerOnly.Value(); got != 0 { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - - if t.Failed() { - t.FailNow() - } - - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) - - // Rx count of the NDP packet should have increased. - if got := typStat.Value(); got != 1 { - t.Errorf("got %s = %d, want = 1", typ.name, got) - } - - want := uint64(0) - if !test.valid { - // Invalid count should have increased. - want = 1 - } - if got := invalid.Value(); got != want { - t.Errorf("got invalid = %d, want = %d", got, want) - } - - want = 0 - if test.valid && !isRouter && typ.routerOnly { - // RouterOnlyPacketsReceivedByHost count should have increased. - want = 1 - } - if got := routerOnly.Value(); got != want { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) - } - - }) + want := uint64(0) + if !test.valid { + // Invalid count should have increased. + want = 1 } + if got := invalid.Value(); got != want { + t.Errorf("got invalid = %d, want = %d", got, want) + } + + want = 0 + if test.valid && !isRouter && typ.routerOnly { + // RouterOnlyPacketsReceivedByHost count should have increased. + want = 1 + } + if got := routerOnly.Value(); got != want { + t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) + } + }) } - } - }) + }) + } } - } // TestNeighborAdvertisementValidation tests that the NIC validates received @@ -1218,7 +969,6 @@ func TestNeighborAdvertisementValidation(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: true, }) e := channel.New(0, header.IPv6MinimumMTU, linkAddr0) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -1291,20 +1041,6 @@ func TestNeighborAdvertisementValidation(t *testing.T) { // NDP Router Advertisement packets, it validates the Router Advertisement // properly before handling them. func TestRouterAdvertValidation(t *testing.T) { - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - tests := []struct { name string src tcpip.Address @@ -1426,67 +1162,62 @@ func TestRouterAdvertValidation(t *testing.T) { }, } - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: stackTyp.useNeighborCache, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } - icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(test.code) - copy(pkt.MessageBody(), test.ndpPayload) - payloadLength := hdr.UsedLength() - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.MessageBody(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - rxRA := stats.RouterAdvert + stats := s.Stats().ICMP.V6.PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) - if got := rxRA.Value(); got != 1 { - t.Fatalf("got rxRA = %d, want = 1", got) - } + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } - if test.expectedSuccess { - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - } else { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - } - }) + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } } }) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index ee23c9b98..49362333a 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -4,18 +4,6 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) go_template_instance( - name = "linkaddrentry_list", - out = "linkaddrentry_list.go", - package = "stack", - prefix = "linkAddrEntry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*linkAddrEntry", - "Linker": "*linkAddrEntry", - }, -) - -go_template_instance( name = "neighbor_entry_list", out = "neighbor_entry_list.go", package = "stack", @@ -62,8 +50,6 @@ go_library( "iptables_state.go", "iptables_targets.go", "iptables_types.go", - "linkaddrcache.go", - "linkaddrentry_list.go", "neighbor_cache.go", "neighbor_entry.go", "neighbor_entry_list.go", @@ -141,7 +127,6 @@ go_test( size = "small", srcs = [ "forwarding_test.go", - "linkaddrcache_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", "nic_test.go", diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 0cb9ec3a3..c987c1851 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -165,9 +165,9 @@ var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) type fwdTestNetworkProtocol struct { stack *Stack - neighborTable neighborTable + neigh *neighborCache addrResolveDelay time.Duration - onLinkAddressResolved func(neighborTable, tcpip.Address, tcpip.LinkAddress) + onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) mu struct { @@ -225,7 +225,7 @@ func (*fwdTestNetworkProtocol) Wait() {} func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { if fn := f.proto.onLinkAddressResolved; fn != nil { time.AfterFunc(f.proto.addrResolveDelay, func() { - fn(f.proto.neighborTable, addr, remoteLinkAddr) + fn(f.proto.neigh, addr, remoteLinkAddr) }) } return nil @@ -361,14 +361,13 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco panic("not implemented") } -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { proto.stack = s return proto }}, - UseNeighborCache: useNeighborCache, }) // Enable forwarding. @@ -406,7 +405,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC } if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok { - proto.neighborTable = l.neighborTable + proto.neigh = &l.neigh } // Route all packets to NIC 2. @@ -422,121 +421,85 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC } func TestForwardingWithStaticResolver(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + // Create a network protocol with a static resolver. + proto := &fwdTestNetworkProtocol{ + onResolveStaticAddress: + // The network address 3 is resolved to the link address "c". + func(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "\x03" { + return "c", true + } + return "", false }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // Create a network protocol with a static resolver. - proto := &fwdTestNetworkProtocol{ - onResolveStaticAddress: - // The network address 3 is resolved to the link address "c". - func(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "\x03" { - return "c", true - } - return "", false - }, - } - - ep1, ep2 := fwdTestNetFactory(t, proto, test.useNeighborCache) + ep1, ep2 := fwdTestNetFactory(t, proto) - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - var p fwdTestPacketInfo + var p fwdTestPacketInfo - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } + select { + case p = <-ep2.C: + default: + t.Fatal("packet not forwarded") + } - // Test that the static address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - }) + // Test that the static address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) } } func TestForwardingWithFakeResolver(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Any address will be resolved to the link address "c". + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any address will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - }) + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) } } @@ -546,7 +509,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // Whether or not we use the neighbor cache here does not matter since // neither linkAddrCache nor neighborCache will be used. - ep1, ep2 := fwdTestNetFactory(t, proto, false /* useNeighborCache */) + ep1, ep2 := fwdTestNetFactory(t, proto) // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -566,12 +529,12 @@ func TestForwardingWithNoResolver(t *testing.T) { func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { proto := &fwdTestNetworkProtocol{ addrResolveDelay: 50 * time.Millisecond, - onLinkAddressResolved: func(neighborTable, tcpip.Address, tcpip.LinkAddress) { + onLinkAddressResolved: func(*neighborCache, tcpip.Address, tcpip.LinkAddress) { // Don't resolve the link address. }, } - ep1, ep2 := fwdTestNetFactory(t, proto, true /* useNeighborCache */) + ep1, ep2 := fwdTestNetFactory(t, proto) const numPackets int = 5 // These packets will all be enqueued in the packet queue to wait for link @@ -596,300 +559,227 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { } func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + } }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - } - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - // Inject an inbound packet to address 4 on NIC 1. This packet should - // not be forwarded. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf = buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + // Inject an inbound packet to address 4 on NIC 1. This packet should + // not be forwarded. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 4 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf = buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - }) + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) } } func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - // Inject two inbound packets to address 3 on NIC 1. - for i := 0; i < 2; i++ { - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } + // Inject two inbound packets to address 3 on NIC 1. + for i := 0; i < 2; i++ { + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - for i := 0; i < 2; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } - }) + for i := 0; i < 2; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } } } func TestForwardingWithFakeResolverManyPackets(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - for i := 0; i < maxPendingPacketsPerResolution+5; i++ { - // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - // Set the packet sequence number. - binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } + for i := 0; i < maxPendingPacketsPerResolution+5; i++ { + // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + // Set the packet sequence number. + binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - for i := 0; i < maxPendingPacketsPerResolution; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - b := PayloadSince(p.Pkt.NetworkHeader()) - if b[dstAddrOffset] != 3 { - t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) - } - if len(b) < fwdTestNetHeaderLen+2 { - t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) - } - seqNumBuf := b[fwdTestNetHeaderLen:] - - // The first 5 packets should not be forwarded so the sequence number should - // start with 5. - want := uint16(i + 5) - if n := binary.BigEndian.Uint16(seqNumBuf); n != want { - t.Fatalf("got the packet #%d, want = #%d", n, want) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } - }) + for i := 0; i < maxPendingPacketsPerResolution; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := PayloadSince(p.Pkt.NetworkHeader()) + if b[dstAddrOffset] != 3 { + t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) + } + if len(b) < fwdTestNetHeaderLen+2 { + t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) + } + seqNumBuf := b[fwdTestNetHeaderLen:] + + // The first 5 packets should not be forwarded so the sequence number should + // start with 5. + want := uint16(i + 5) + if n := binary.BigEndian.Uint16(seqNumBuf); n != want { + t.Fatalf("got the packet #%d, want = #%d", n, want) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } } } func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - proto *fwdTestNetworkProtocol - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - for i := 0; i < maxPendingResolutions+5; i++ { - // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. - // Each packet has a different destination address (3 to - // maxPendingResolutions + 7). - buf := buffer.NewView(30) - buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } + for i := 0; i < maxPendingResolutions+5; i++ { + // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. + // Each packet has a different destination address (3 to + // maxPendingResolutions + 7). + buf := buffer.NewView(30) + buf[dstAddrOffset] = byte(3 + i) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - for i := 0; i < maxPendingResolutions; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - // The first 5 packets (address 3 to 7) should not be forwarded - // because their address resolutions are interrupted. - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } - }) + for i := 0; i < maxPendingResolutions; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // The first 5 packets (address 3 to 7) should not be forwarded + // because their address resolutions are interrupted. + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } } } diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go deleted file mode 100644 index 9ec04dd50..000000000 --- a/pkg/tcpip/stack/linkaddrcache.go +++ /dev/null @@ -1,359 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" -) - -const linkAddrCacheSize = 512 // max cache entries - -// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses. -// -// The entries are stored in a ring buffer, oldest entry replaced first. -// -// This struct is safe for concurrent use. -type linkAddrCache struct { - nic *nic - - linkRes LinkAddressResolver - - // ageLimit is how long a cache entry is valid for. - ageLimit time.Duration - - // resolutionTimeout is the amount of time to wait for a link request to - // resolve an address. - resolutionTimeout time.Duration - - // resolutionAttempts is the number of times an address is attempted to be - // resolved before failing. - resolutionAttempts int - - mu struct { - sync.Mutex - table map[tcpip.Address]*linkAddrEntry - lru linkAddrEntryList - } -} - -// entryState controls the state of a single entry in the cache. -type entryState int - -const ( - // incomplete means that there is an outstanding request to resolve the - // address. This is the initial state. - incomplete entryState = iota - // ready means that the address has been resolved and can be used. - ready -) - -// String implements Stringer. -func (s entryState) String() string { - switch s { - case incomplete: - return "incomplete" - case ready: - return "ready" - default: - return fmt.Sprintf("unknown(%d)", s) - } -} - -// A linkAddrEntry is an entry in the linkAddrCache. -// This struct is thread-compatible. -type linkAddrEntry struct { - // linkAddrEntryEntry access is synchronized by the linkAddrCache lock. - linkAddrEntryEntry - - cache *linkAddrCache - - mu struct { - sync.RWMutex - - addr tcpip.Address - linkAddr tcpip.LinkAddress - expiration time.Time - s entryState - - // done is closed when address resolution is complete. It is nil iff s is - // incomplete and resolution is not yet in progress. - done chan struct{} - - // onResolve is called with the result of address resolution. - onResolve []func(LinkResolutionResult) - } -} - -func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { - res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0} - for _, callback := range e.mu.onResolve { - callback(res) - } - e.mu.onResolve = nil - if ch := e.mu.done; ch != nil { - close(ch) - e.mu.done = nil - // Dequeue the pending packets in a new goroutine to not hold up the current - // goroutine as writing packets may be a costly operation. - // - // At the time of writing, when writing packets, a neighbor's link address - // is resolved (which ends up obtaining the entry's lock) while holding the - // link resolution queue's lock. Dequeuing packets in a new goroutine avoids - // a lock ordering violation. - go e.cache.nic.linkResQueue.dequeue(ch, linkAddr, len(linkAddr) != 0) - } -} - -// changeStateLocked sets the entry's state to ns. -// -// The entry's expiration is bumped up to the greater of itself and the passed -// expiration; the zero value indicates immediate expiration, and is set -// unconditionally - this is an implementation detail that allows for entries -// to be reused. -// -// Precondition: e.mu must be locked -func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { - if e.mu.s == incomplete && ns == ready { - e.notifyCompletionLocked(e.mu.linkAddr) - } - - if expiration.IsZero() || expiration.After(e.mu.expiration) { - e.mu.expiration = expiration - } - e.mu.s = ns -} - -// add adds a k -> v mapping to the cache. -func (c *linkAddrCache) add(k tcpip.Address, v tcpip.LinkAddress) { - // Calculate expiration time before acquiring the lock, since expiration is - // relative to the time when information was learned, rather than when it - // happened to be inserted into the cache. - expiration := time.Now().Add(c.ageLimit) - - c.mu.Lock() - entry := c.getOrCreateEntryLocked(k) - entry.mu.Lock() - defer entry.mu.Unlock() - c.mu.Unlock() - - entry.mu.linkAddr = v - entry.changeStateLocked(ready, expiration) -} - -// getOrCreateEntryLocked retrieves a cache entry associated with k. The -// returned entry is always refreshed in the cache (it is reachable via the -// map, and its place is bumped in LRU). -// -// If a matching entry exists in the cache, it is returned. If no matching -// entry exists and the cache is full, an existing entry is evicted via LRU, -// reset to state incomplete, and returned. If no matching entry exists and the -// cache is not full, a new entry with state incomplete is allocated and -// returned. -func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { - if entry, ok := c.mu.table[k]; ok { - c.mu.lru.Remove(entry) - c.mu.lru.PushFront(entry) - return entry - } - var entry *linkAddrEntry - if len(c.mu.table) == linkAddrCacheSize { - entry = c.mu.lru.Back() - entry.mu.Lock() - - delete(c.mu.table, entry.mu.addr) - c.mu.lru.Remove(entry) - - // Wake waiters and mark the soon-to-be-reused entry as expired. - entry.notifyCompletionLocked("" /* linkAddr */) - entry.mu.Unlock() - } else { - entry = new(linkAddrEntry) - } - - *entry = linkAddrEntry{ - cache: c, - } - entry.mu.Lock() - entry.mu.addr = k - entry.mu.s = incomplete - entry.mu.Unlock() - c.mu.table[k] = entry - c.mu.lru.PushFront(entry) - return entry -} - -// get reports any known link address for addr. -func (c *linkAddrCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - c.mu.Lock() - entry := c.getOrCreateEntryLocked(addr) - entry.mu.Lock() - defer entry.mu.Unlock() - c.mu.Unlock() - - switch s := entry.mu.s; s { - case ready: - if !time.Now().After(entry.mu.expiration) { - // Not expired. - if onResolve != nil { - onResolve(LinkResolutionResult{LinkAddress: entry.mu.linkAddr, Success: true}) - } - return entry.mu.linkAddr, nil, nil - } - - entry.changeStateLocked(incomplete, time.Time{}) - fallthrough - case incomplete: - if onResolve != nil { - entry.mu.onResolve = append(entry.mu.onResolve, onResolve) - } - if entry.mu.done == nil { - entry.mu.done = make(chan struct{}) - go c.startAddressResolution(addr, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. - } - return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{} - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) - } -} - -func (c *linkAddrCache) startAddressResolution(k tcpip.Address, localAddr tcpip.Address, done <-chan struct{}) { - for i := 0; ; i++ { - // Send link request, then wait for the timeout limit and check - // whether the request succeeded. - c.linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */) - - select { - case now := <-time.After(c.resolutionTimeout): - if stop := c.checkLinkRequest(now, k, i); stop { - return - } - case <-done: - return - } - } -} - -// checkLinkRequest checks whether previous attempt to resolve address has -// succeeded and mark the entry accordingly. Returns true if request can stop, -// false if another request should be sent. -func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt int) bool { - c.mu.Lock() - defer c.mu.Unlock() - entry, ok := c.mu.table[k] - if !ok { - // Entry was evicted from the cache. - return true - } - entry.mu.Lock() - defer entry.mu.Unlock() - - switch s := entry.mu.s; s { - case ready: - // Entry was made ready by resolver. - case incomplete: - if attempt+1 < c.resolutionAttempts { - // No response yet, need to send another ARP request. - return false - } - // Max number of retries reached, delete entry. - entry.notifyCompletionLocked("" /* linkAddr */) - delete(c.mu.table, k) - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) - } - return true -} - -func (c *linkAddrCache) init(nic *nic, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int, linkRes LinkAddressResolver) { - *c = linkAddrCache{ - nic: nic, - linkRes: linkRes, - ageLimit: ageLimit, - resolutionTimeout: resolutionTimeout, - resolutionAttempts: resolutionAttempts, - } - - c.mu.Lock() - c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) - c.mu.Unlock() -} - -var _ neighborTable = (*linkAddrCache)(nil) - -func (*linkAddrCache) neighbors() ([]NeighborEntry, tcpip.Error) { - return nil, &tcpip.ErrNotSupported{} -} - -func (c *linkAddrCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) { - c.add(addr, linkAddr) -} - -func (*linkAddrCache) remove(addr tcpip.Address) tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (*linkAddrCache) removeAll() tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (c *linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress) { - if len(linkAddr) != 0 { - // NUD allows probes without a link address but linkAddrCache - // is a simple neighbor table which does not implement NUD. - // - // As per RFC 4861 section 4.3, - // - // Source link-layer address - // The link-layer address for the sender. MUST NOT be - // included when the source IP address is the - // unspecified address. Otherwise, on link layers - // that have addresses this option MUST be included in - // multicast solicitations and SHOULD be included in - // unicast solicitations. - c.add(addr, linkAddr) - } -} - -func (c *linkAddrCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { - if len(linkAddr) != 0 { - // NUD allows confirmations without a link address but linkAddrCache - // is a simple neighbor table which does not implement NUD. - // - // As per RFC 4861 section 4.4, - // - // Target link-layer address - // The link-layer address for the target, i.e., the - // sender of the advertisement. This option MUST be - // included on link layers that have addresses when - // responding to multicast solicitations. When - // responding to a unicast Neighbor Solicitation this - // option SHOULD be included. - c.add(addr, linkAddr) - } -} - -func (c *linkAddrCache) handleUpperLevelConfirmation(tcpip.Address) {} - -func (*linkAddrCache) nudConfig() (NUDConfigurations, tcpip.Error) { - return NUDConfigurations{}, &tcpip.ErrNotSupported{} -} - -func (*linkAddrCache) setNUDConfig(NUDConfigurations) tcpip.Error { - return &tcpip.ErrNotSupported{} -} diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go deleted file mode 100644 index d2de0de1e..000000000 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ /dev/null @@ -1,291 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "math" - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" -) - -type testaddr struct { - addr tcpip.Address - linkAddr tcpip.LinkAddress -} - -var testAddrs = func() []testaddr { - var addrs []testaddr - for i := 0; i < 4*linkAddrCacheSize; i++ { - addr := fmt.Sprintf("Addr%06d", i) - addrs = append(addrs, testaddr{ - addr: tcpip.Address(addr), - linkAddr: tcpip.LinkAddress("Link" + addr), - }) - } - return addrs -}() - -type testLinkAddressResolver struct { - cache *linkAddrCache - delay time.Duration - onLinkAddressRequest func() -} - -func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { - // TODO(gvisor.dev/issue/5141): Use a fake clock. - time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) - if f := r.onLinkAddressRequest; f != nil { - f() - } - return nil -} - -func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { - for _, ta := range testAddrs { - if ta.addr == addr { - r.cache.add(ta.addr, ta.linkAddr) - break - } - } -} - -func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "broadcast" { - return "mac_broadcast", true - } - return "", false -} - -func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return 1 -} - -func getBlocking(c *linkAddrCache, addr tcpip.Address) (tcpip.LinkAddress, tcpip.Error) { - var attemptedResolution bool - for { - got, ch, err := c.get(addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - if attemptedResolution { - return got, &tcpip.ErrTimeout{} - } - attemptedResolution = true - <-ch - continue - } - return got, err - } -} - -func newEmptyNIC() *nic { - n := &nic{} - n.linkResQueue.init(n) - return n -} - -func TestCacheOverflow(t *testing.T) { - var c linkAddrCache - c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) - for i := len(testAddrs) - 1; i >= 0; i-- { - e := testAddrs[i] - c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("insert %d, c.get(%s, '', nil): %s", i, e.addr, err) - } - if got != e.linkAddr { - t.Errorf("insert %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) - } - } - // Expect to find at least half of the most recent entries. - for i := 0; i < linkAddrCacheSize/2; i++ { - e := testAddrs[i] - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("check %d, c.get(%s, '', nil): %s", i, e.addr, err) - } - if got != e.linkAddr { - t.Errorf("check %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) - } - } - // The earliest entries should no longer be in the cache. - c.mu.Lock() - defer c.mu.Unlock() - for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { - e := testAddrs[i] - if entry, ok := c.mu.table[e.addr]; ok { - t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) - } - } -} - -func TestCacheConcurrent(t *testing.T) { - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c} - c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, linkRes) - - var wg sync.WaitGroup - for r := 0; r < 16; r++ { - wg.Add(1) - go func() { - for _, e := range testAddrs { - c.add(e.addr, e.linkAddr) - } - wg.Done() - }() - } - wg.Wait() - - // All goroutines add in the same order and add more values than - // can fit in the cache, so our eviction strategy requires that - // the last entry be present and the first be missing. - e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("c.get(%s, '', nil): %s", e.addr, err) - } - if got != e.linkAddr { - t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) - } - - e = testAddrs[0] - c.mu.Lock() - defer c.mu.Unlock() - if entry, ok := c.mu.table[e.addr]; ok { - t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) - } -} - -func TestCacheAgeLimit(t *testing.T) { - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c} - c.init(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3, linkRes) - - e := testAddrs[0] - c.add(e.addr, e.linkAddr) - time.Sleep(50 * time.Millisecond) - _, _, err := c.get(e.addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.get(%s, '', nil) = %s, want = ErrWouldBlock", e.addr, err) - } -} - -func TestCacheReplace(t *testing.T) { - var c linkAddrCache - c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) - e := testAddrs[0] - l2 := e.linkAddr + "2" - c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("c.get(%s, '', nil): %s", e.addr, err) - } - if got != e.linkAddr { - t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) - } - - c.add(e.addr, l2) - got, _, err = c.get(e.addr, "", nil) - if err != nil { - t.Errorf("c.get(%s, '', nil): %s", e.addr, err) - } - if got != l2 { - t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, l2) - } -} - -func TestCacheResolution(t *testing.T) { - // There is a race condition causing this test to fail when the executor - // takes longer than the resolution timeout to call linkAddrCache.get. This - // is especially common when this test is run with gotsan. - // - // Using a large resolution timeout decreases the probability of experiencing - // this race condition and does not affect how long this test takes to run. - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c} - c.init(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1, linkRes) - for i, ta := range testAddrs { - got, err := getBlocking(&c, ta.addr) - if err != nil { - t.Errorf("check %d, getBlocking(_, %s): %s", i, ta.addr, err) - } - if got != ta.linkAddr { - t.Errorf("check %d, got getBlocking(_, %s) = %s, want = %s", i, ta.addr, got, ta.linkAddr) - } - } - - // Check that after resolved, address stays in the cache and never returns WouldBlock. - for i := 0; i < 10; i++ { - e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("c.get(%s, '', nil): %s", e.addr, err) - } - if got != e.linkAddr { - t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) - } - } -} - -func TestCacheResolutionFailed(t *testing.T) { - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c} - c.init(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5, linkRes) - - var requestCount uint32 - linkRes.onLinkAddressRequest = func() { - atomic.AddUint32(&requestCount, 1) - } - - // First, sanity check that resolution is working... - e := testAddrs[0] - got, err := getBlocking(&c, e.addr) - if err != nil { - t.Errorf("getBlocking(_, %s): %s", e.addr, err) - } - if got != e.linkAddr { - t.Errorf("got getBlocking(_, %s) = %s, want = %s", e.addr, got, e.linkAddr) - } - - before := atomic.LoadUint32(&requestCount) - - e.addr += "2" - a, err := getBlocking(&c, e.addr) - if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) - } - - if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { - t.Errorf("got link address request count = %d, want = %d", got, want) - } -} - -func TestCacheResolutionTimeout(t *testing.T) { - resolverDelay := 500 * time.Millisecond - expiration := resolverDelay / 10 - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c, delay: resolverDelay} - c.init(newEmptyNIC(), expiration, 1*time.Millisecond, 3, linkRes) - - e := testAddrs[0] - a, err := getBlocking(&c, e.addr) - if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) - } -} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 0238605af..4fa4101b0 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -2769,7 +2769,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // stack.Stack will have a default route through the router (llAddr3) installed // and a static link-address (linkAddr3) added to the link address cache for the // router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useNeighborCache bool) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { t.Helper() ndpDisp := &ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), @@ -2785,7 +2785,6 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN NDPDisp: ndpDisp, })}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2868,126 +2867,108 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA // TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when // receiving a PI with 0 preferred lifetime. func TestAutoGenAddrDeprecateFromPI(t *testing.T) { - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - const nicID = 1 + const nicID = 1 - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } + } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) + } - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Receive PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - expectPrimaryAddr(addr1) + // Receive PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + expectPrimaryAddr(addr1) - // Deprecate addr for prefix1 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - // addr should still be the primary endpoint as there are no other addresses. - expectPrimaryAddr(addr1) + // Deprecate addr for prefix1 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + // addr should still be the primary endpoint as there are no other addresses. + expectPrimaryAddr(addr1) - // Refresh lifetimes of addr generated from prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) + // Refresh lifetimes of addr generated from prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Deprecate addr for prefix2 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr1 should be the primary endpoint now since addr2 is deprecated but - // addr1 is not. - expectPrimaryAddr(addr1) - // addr2 is deprecated but if explicitly requested, it should be used. - fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } + // Deprecate addr for prefix2 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr1 should be the primary endpoint now since addr2 is deprecated but + // addr1 is not. + expectPrimaryAddr(addr1) + // addr2 is deprecated but if explicitly requested, it should be used. + fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } - // Another PI w/ 0 preferred lifetime should not result in a deprecation - // event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } + // Another PI w/ 0 preferred lifetime should not result in a deprecation + // event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } - // Refresh lifetimes of addr generated from prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) - }) + // Refresh lifetimes of addr generated from prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: } + expectPrimaryAddr(addr2) } // TestAutoGenAddrJobDeprecation tests that an address is properly deprecated @@ -2997,232 +2978,214 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { const newMinVL = 2 newMinVLDuration := newMinVL * time.Second - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate + defer func() { + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved + }() + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } + } - expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() + expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(timeout): - t.Fatal("timed out waiting for addr auto gen event") - } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) + } - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Receive a PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr1) + // Receive a PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr1) - // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) + // Refresh lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since addr1 is deprecated but - // addr2 is not. - expectPrimaryAddr(addr2) - // addr1 is deprecated but if explicitly requested, it should be used. - fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since addr1 is deprecated but + // addr2 is not. + expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. + fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make - // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make + // sure we do not get a deprecation event again. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - // addr1 is the primary endpoint again since it is non-deprecated now. - expectPrimaryAddr(addr1) + // Refresh lifetimes for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + // addr1 is the primary endpoint again since it is non-deprecated now. + expectPrimaryAddr(addr1) - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since it is not deprecated. - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since it is not deprecated. + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Wait for addr of prefix1 to be invalidated. + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } + // Refresh both lifetimes for addr of prefix2 to the same value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } - // Wait for a deprecation then invalidation events, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation handlers could be handled in - // either deprecation then invalidation, or invalidation then deprecation - // (which should be cancelled by the invalidation handler). + // Wait for a deprecation then invalidation events, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation handlers could be handled in + // either deprecation then invalidation, or invalidation then deprecation + // (which should be cancelled by the invalidation handler). + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. select { case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation - // event after. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultAsyncNegativeEventTimeout): - } - } else { - t.Fatalf("got unexpected auto-generated event") + if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should not have %s in the list of addresses", addr2) - } - // Should not have any primary endpoints. - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we should not get a deprecation + // event after. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + case <-time.After(defaultAsyncNegativeEventTimeout): } - defer ep.Close() - ep.SocketOptions().SetV6Only(true) + } else { + t.Fatalf("got unexpected auto-generated event") + } + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should not have %s in the list of addresses", addr2) + } + // Should not have any primary endpoints. + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) + } + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + ep.SocketOptions().SetV6Only(true) - { - err := ep.Connect(dstAddr) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{}) - } - } - }) + { + err := ep.Connect(dstAddr) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{}) + } } } @@ -3543,126 +3506,108 @@ func TestAutoGenAddrRemoval(t *testing.T) { func TestAutoGenAddrAfterRemoval(t *testing.T) { const nicID = 1 - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } } - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + // Receive a PI to auto-generate addr1 with a large valid and preferred + // lifetime. + const largeLifetimeSeconds = 999 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr1, newAddr) + expectPrimaryAddr(addr1) - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } + // Add addr2 as a static address. + protoAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr2, + } + if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + } + // addr2 should be more preferred now since it is at the front of the primary + // list. + expectPrimaryAddr(addr2) - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + // Get a route using addr2 to increment its reference count then remove it + // to leave it in the permanentExpired state. + r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) + } + defer r.Release() + if err := s.RemoveAddress(nicID, addr2.Address); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) + } + // addr1 should be preferred again since addr2 is in the expired state. + expectPrimaryAddr(addr1) - // Receive a PI to auto-generate addr1 with a large valid and preferred - // lifetime. - const largeLifetimeSeconds = 999 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr1, newAddr) - expectPrimaryAddr(addr1) - - // Add addr2 as a static address. - protoAddr2 := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr2, - } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) - } - // addr2 should be more preferred now since it is at the front of the primary - // list. - expectPrimaryAddr(addr2) - - // Get a route using addr2 to increment its reference count then remove it - // to leave it in the permanentExpired state. - r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) - } - defer r.Release() - if err := s.RemoveAddress(nicID, addr2.Address); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) - } - // addr1 should be preferred again since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and preferred. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr2 should be more preferred now that it is closer to the front of the - // primary list and not deprecated. - expectPrimaryAddr(addr2) - - // Removing the address should result in an invalidation event immediately. - // It should still be in the permanentExpired state because r is still held. - // - // We remove addr2 here to make sure addr2 was marked as a SLAAC address - // (it was previously marked as a static address). - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - // addr1 should be more preferred since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr1 should still be more preferred since addr2 is deprecated, even though - // it is closer to the front of the primary list. - expectPrimaryAddr(addr1) - - // Receive a PI to refresh addr2's preferred lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto gen addr event") - default: - } - // addr2 should be more preferred now that it is not deprecated. - expectPrimaryAddr(addr2) + // Receive a PI to auto-generate addr2 as valid and preferred. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr2 should be more preferred now that it is closer to the front of the + // primary list and not deprecated. + expectPrimaryAddr(addr2) - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - expectPrimaryAddr(addr1) - }) + // Removing the address should result in an invalidation event immediately. + // It should still be in the permanentExpired state because r is still held. + // + // We remove addr2 here to make sure addr2 was marked as a SLAAC address + // (it was previously marked as a static address). + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + // addr1 should be more preferred since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr1 should still be more preferred since addr2 is deprecated, even though + // it is closer to the front of the primary list. + expectPrimaryAddr(addr1) + + // Receive a PI to refresh addr2's preferred lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto gen addr event") + default: + } + // addr2 should be more preferred now that it is not deprecated. + expectPrimaryAddr(addr2) + + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + expectPrimaryAddr(addr1) } // TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 1409604df..a77fe575a 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -266,30 +266,6 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { n.state.SetConfig(config) } -var _ neighborTable = (*neighborCache)(nil) - -func (n *neighborCache) neighbors() ([]NeighborEntry, tcpip.Error) { - return n.entries(), nil -} - -func (n *neighborCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - entry, ch, err := n.entry(addr, localAddr, onResolve) - return entry.LinkAddr, ch, err -} - -func (n *neighborCache) remove(addr tcpip.Address) tcpip.Error { - if !n.removeEntry(addr) { - return &tcpip.ErrBadAddress{} - } - - return nil -} - -func (n *neighborCache) removeAll() tcpip.Error { - n.clear() - return nil -} - // handleProbe handles a neighbor probe as defined by RFC 4861 section 7.2.3. // // Validation of the probe is expected to be handled by the caller. @@ -331,17 +307,8 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { } } -func (n *neighborCache) nudConfig() (NUDConfigurations, tcpip.Error) { - return n.config(), nil -} - -func (n *neighborCache) setNUDConfig(c NUDConfigurations) tcpip.Error { - n.setConfig(c) - return nil -} - -func newNeighborCache(nic *nic, r LinkAddressResolver) *neighborCache { - n := &neighborCache{ +func (n *neighborCache) init(nic *nic, r LinkAddressResolver) { + *n = neighborCache{ nic: nic, state: NewNUDState(nic.stack.nudConfigs, nic.stack.randomGenerator), linkRes: r, @@ -349,5 +316,4 @@ func newNeighborCache(nic *nic, r LinkAddressResolver) *neighborCache { n.mu.Lock() n.mu.cache = make(map[tcpip.Address]*neighborEntry, neighborCacheSize) n.mu.Unlock() - return n } diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 6ff215e99..d60a49fe8 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -84,7 +84,7 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl entries: newTestEntryStore(), delay: typicalLatency, } - linkRes.neigh = newNeighborCache(&nic{ + linkRes.neigh.init(&nic{ stack: &Stack{ clock: clock, nudDisp: nudDisp, @@ -187,7 +187,7 @@ func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { // neighbor probe. type testNeighborResolver struct { clock tcpip.Clock - neigh *neighborCache + neigh neighborCache entries *testEntryStore delay time.Duration onLinkAddressRequest func() diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 725b981b5..f80d4ef8a 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -239,16 +239,16 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e var linkRes entryTestLinkResolver // Stub out the neighbor cache to verify deletion from the cache. - neigh := newNeighborCache(&nic, &linkRes) - l := linkResolver{ - resolver: &linkRes, - neighborTable: neigh, - } - entry := newNeighborEntry(neigh, entryTestAddr1 /* remoteAddr */, neigh.state) - neigh.mu.Lock() - neigh.mu.cache[entryTestAddr1] = entry - neigh.mu.Unlock() - nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]linkResolver{ + l := &linkResolver{ + resolver: &linkRes, + } + l.neigh.init(&nic, &linkRes) + + entry := newNeighborEntry(&l.neigh, entryTestAddr1 /* remoteAddr */, l.neigh.state) + l.neigh.mu.Lock() + l.neigh.mu.cache[entryTestAddr1] = entry + l.neigh.mu.Unlock() + nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]*linkResolver{ header.IPv6ProtocolNumber: l, } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f85652093..00cfba35a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -24,37 +24,23 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -type neighborTable interface { - neighbors() ([]NeighborEntry, tcpip.Error) - addStaticEntry(tcpip.Address, tcpip.LinkAddress) - get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) - remove(tcpip.Address) tcpip.Error - removeAll() tcpip.Error - - handleProbe(tcpip.Address, tcpip.LinkAddress) - handleConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) - handleUpperLevelConfirmation(tcpip.Address) - - nudConfig() (NUDConfigurations, tcpip.Error) - setNUDConfig(NUDConfigurations) tcpip.Error -} - -var _ NetworkInterface = (*nic)(nil) - type linkResolver struct { resolver LinkAddressResolver - neighborTable neighborTable + neigh neighborCache } func (l *linkResolver) getNeighborLinkAddress(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - return l.neighborTable.get(addr, localAddr, onResolve) + entry, ch, err := l.neigh.entry(addr, localAddr, onResolve) + return entry.LinkAddr, ch, err } func (l *linkResolver) confirmReachable(addr tcpip.Address) { - l.neighborTable.handleUpperLevelConfirmation(addr) + l.neigh.handleUpperLevelConfirmation(addr) } +var _ NetworkInterface = (*nic)(nil) + // nic represents a "network interface card" to which the networking stack is // attached. type nic struct { @@ -70,7 +56,7 @@ type nic struct { // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint - linkAddrResolvers map[tcpip.NetworkProtocolNumber]linkResolver + linkAddrResolvers map[tcpip.NetworkProtocolNumber]*linkResolver // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. // @@ -165,7 +151,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC context: ctx, stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]linkResolver), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver), } nic.linkResQueue.init(nic) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) @@ -185,17 +171,8 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC if resolutionRequired { if r, ok := netEP.(LinkAddressResolver); ok { - l := linkResolver{ - resolver: r, - } - - if stack.useNeighborCache { - l.neighborTable = newNeighborCache(nic, r) - } else { - cache := new(linkAddrCache) - cache.init(nic, ageLimit, resolutionTimeout, resolutionAttempts, r) - l.neighborTable = cache - } + l := &linkResolver{resolver: r} + l.neigh.init(nic, r) nic.linkAddrResolvers[r.LinkAddressProtocol()] = l } } @@ -640,7 +617,7 @@ func (n *nic) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.Netwo func (n *nic) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - return linkRes.neighborTable.neighbors() + return linkRes.neigh.entries(), nil } return nil, &tcpip.ErrNotSupported{} @@ -648,7 +625,7 @@ func (n *nic) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, func (n *nic) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtocolNumber, linkAddress tcpip.LinkAddress) tcpip.Error { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - linkRes.neighborTable.addStaticEntry(addr, linkAddress) + linkRes.neigh.addStaticEntry(addr, linkAddress) return nil } @@ -657,7 +634,10 @@ func (n *nic) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtoc func (n *nic) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - return linkRes.neighborTable.remove(addr) + if !linkRes.neigh.removeEntry(addr) { + return &tcpip.ErrBadAddress{} + } + return nil } return &tcpip.ErrNotSupported{} @@ -665,7 +645,8 @@ func (n *nic) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Ad func (n *nic) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - return linkRes.neighborTable.removeAll() + linkRes.neigh.clear() + return nil } return &tcpip.ErrNotSupported{} @@ -923,7 +904,7 @@ func (n *nic) Name() string { // nudConfigs gets the NUD configurations for n. func (n *nic) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - return linkRes.neighborTable.nudConfig() + return linkRes.neigh.config(), nil } return NUDConfigurations{}, &tcpip.ErrNotSupported{} @@ -936,7 +917,8 @@ func (n *nic) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfiguration func (n *nic) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { c.resetInvalidFields() - return linkRes.neighborTable.setNUDConfig(c) + linkRes.neigh.setConfig(c) + return nil } return &tcpip.ErrNotSupported{} @@ -979,7 +961,7 @@ func (n *nic) isValidForOutgoing(ep AssignableAddressEndpoint) bool { // HandleNeighborProbe implements NetworkInterface. func (n *nic) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { if l, ok := n.linkAddrResolvers[protocol]; ok { - l.neighborTable.handleProbe(addr, linkAddr) + l.neigh.handleProbe(addr, linkAddr) return nil } @@ -989,7 +971,7 @@ func (n *nic) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcp // HandleNeighborConfirmation implements NetworkInterface. func (n *nic) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) tcpip.Error { if l, ok := n.linkAddrResolvers[protocol]; ok { - l.neighborTable.handleConfirmation(addr, linkAddr, flags) + l.neigh.handleConfirmation(addr, linkAddr, flags) return nil } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index e9acef6a2..e1253f310 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -101,7 +101,6 @@ func TestNUDFunctions(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, NetworkProtocols: test.netProtoFactory, Clock: clock, }) @@ -206,7 +205,6 @@ func TestDefaultNUDConfigurations(t *testing.T) { // address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -261,7 +259,6 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -318,7 +315,6 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -398,7 +394,6 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -460,7 +455,6 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -512,7 +506,6 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -564,7 +557,6 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -616,7 +608,6 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 2ce1a65c6..e946f9fe3 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -53,7 +53,7 @@ type Route struct { // linkRes is set if link address resolution is enabled for this protocol on // the route's NIC. - linkRes linkResolver + linkRes *linkResolver } type routeInfo struct { @@ -184,7 +184,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA return r } - if r.linkRes.resolver == nil { + if r.linkRes == nil { return r } @@ -400,7 +400,7 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isResolutionRequiredRLocked() bool { - return len(r.mu.remoteLinkAddress) == 0 && r.linkRes.resolver != nil && r.isValidForOutgoingRLocked() && !r.local() + return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local() } func (r *Route) isValidForOutgoing() bool { @@ -528,7 +528,7 @@ func (r *Route) IsOutboundBroadcast() bool { // "Reachable" is defined as having full-duplex communication between the // local and remote ends of the route. func (r *Route) ConfirmReachable() { - if r.linkRes.resolver != nil { + if r.linkRes != nil { r.linkRes.confirmReachable(r.nextHop()) } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 07cd88a6a..46c1817ea 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -434,12 +434,6 @@ type Stack struct { // nudConfigs is the default NUD configurations used by interfaces. nudConfigs NUDConfigurations - // useNeighborCache indicates whether ARP and NDP packets should be handled - // by the NIC's neighborCache instead of linkAddrCache. - // - // TODO(gvisor.dev/issue/4658): Remove this field. - useNeighborCache bool - // nudDisp is the NUD event dispatcher that is used to send the netstack // integrator NUD related events. nudDisp NUDDispatcher @@ -516,17 +510,6 @@ type Options struct { // NUDConfigs is the default NUD configurations used by interfaces. NUDConfigs NUDConfigurations - // UseNeighborCache is unused. - // - // TODO(gvisor.dev/issue/4658): Remove this field. - UseNeighborCache bool - - // UseLinkAddrCache indicates that the legacy link address cache should be - // used for link resolution. - // - // TODO(gvisor.dev/issue/4658): Remove this field. - UseLinkAddrCache bool - // NUDDisp is the NUD event dispatcher that an integrator can provide to // receive NUD related events. NUDDisp NUDDispatcher @@ -666,7 +649,6 @@ func New(opts Options) *Stack { icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), nudConfigs: opts.NUDConfigs, - useNeighborCache: !opts.UseLinkAddrCache, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b3386f705..62066b3aa 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -4324,7 +4324,6 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - UseNeighborCache: true, Clock: clock, }) e := channel.New(0, 0, "") diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index f2301a9e6..553ec950f 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -487,30 +487,25 @@ func TestGetLinkAddress(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - for _, useNeighborCache := range []bool{true, false} { - t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - UseNeighborCache: useNeighborCache, - } + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + } - host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) + host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - ch := make(chan stack.LinkResolutionResult, 1) - err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { - ch <- r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) - } - wantRes := stack.LinkResolutionResult{Success: test.expectedOk} - if test.expectedOk { - wantRes.LinkAddress = linkAddr2 - } - if diff := cmp.Diff(wantRes, <-ch); diff != "" { - t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) - } - }) + ch := make(chan stack.LinkResolutionResult, 1) + err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + ch <- r + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) + } + wantRes := stack.LinkResolutionResult{Success: test.expectedOk} + if test.expectedOk { + wantRes.LinkAddress = linkAddr2 + } + if diff := cmp.Diff(wantRes, <-ch); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) } }) } @@ -587,66 +582,61 @@ func TestRouteResolvedFields(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - for _, useNeighborCache := range []bool{true, false} { - t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - UseNeighborCache: useNeighborCache, - } + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + } - host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) - if err != nil { - t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) - } - defer r.Release() - - var wantRouteInfo stack.RouteInfo - wantRouteInfo.LocalLinkAddress = linkAddr1 - wantRouteInfo.LocalAddress = test.localAddr - wantRouteInfo.RemoteAddress = test.remoteAddr - wantRouteInfo.NetProto = test.netProto - wantRouteInfo.Loop = stack.PacketOut - wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr - - ch := make(chan stack.ResolvedFieldsResult, 1) - - if !test.immediatelyResolvable { - wantUnresolvedRouteInfo := wantRouteInfo - wantUnresolvedRouteInfo.RemoteLinkAddress = "" - - err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { - ch <- r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) - } - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) - } + host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) + r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) + if err != nil { + t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) + } + defer r.Release() - if !test.expectedSuccess { - return - } + var wantRouteInfo stack.RouteInfo + wantRouteInfo.LocalLinkAddress = linkAddr1 + wantRouteInfo.LocalAddress = test.localAddr + wantRouteInfo.RemoteAddress = test.remoteAddr + wantRouteInfo.NetProto = test.netProto + wantRouteInfo.Loop = stack.PacketOut + wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr - // At this point the neighbor table should be populated so the route - // should be immediately resolvable. - } + ch := make(chan stack.ResolvedFieldsResult, 1) - if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { - ch <- r - }); err != nil { - t.Errorf("r.ResolvedFields(_): %s", err) - } - select { - case routeResolveRes := <-ch: - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected route to be immediately resolvable") - } + if !test.immediatelyResolvable { + wantUnresolvedRouteInfo := wantRouteInfo + wantUnresolvedRouteInfo.RemoteLinkAddress = "" + + err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + ch <- r }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) + } + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) + } + + if !test.expectedSuccess { + return + } + + // At this point the neighbor table should be populated so the route + // should be immediately resolvable. + } + + if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + ch <- r + }); err != nil { + t.Errorf("r.ResolvedFields(_): %s", err) + } + select { + case routeResolveRes := <-ch: + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected route to be immediately resolvable") } }) } @@ -1065,7 +1055,6 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, Clock: clock, - UseNeighborCache: true, } host1StackOpts := stackOpts host1StackOpts.NUDDisp = &nudDisp |