diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-01-31 11:31:55 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-01-31 11:33:46 -0800 |
commit | 4ee8cf8734d24c7ba78700c21dff561207d4ed1a (patch) | |
tree | 498bd24ce2efd684df021ee4b6d814530457edaa | |
parent | daeb06d2cbf5509bd53dc67138504e51d0fcfae8 (diff) |
Use different neighbor tables per network endpoint
This stores each protocol's neighbor state separately.
This change also removes the need for each neighbor entry to keep
track of their own link address resolver now that all the entries
in a cache will use the same resolver.
PiperOrigin-RevId: 354818155
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/arp_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 95 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache_test.go | 434 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry.go | 61 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_test.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 150 | ||||
-rw-r--r-- | pkg/tcpip/stack/nud_test.go | 218 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 57 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 4 |
21 files changed, 669 insertions, 566 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 5fd4c5574..0d7fadc31 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -148,7 +148,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { remoteAddr := tcpip.Address(h.ProtocolAddressSender()) remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.nic.HandleNeighborProbe(remoteAddr, remoteLinkAddr, e) + switch err := e.nic.HandleNeighborProbe(header.IPv4ProtocolNumber, remoteAddr, remoteLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ARP but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) + } respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, @@ -190,7 +196,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // The solicited, override, and isRouter flags are not available for ARP; // they are only available for IPv6 Neighbor Advertisements. - e.nic.HandleNeighborConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{ + switch err := e.nic.HandleNeighborConfirmation(header.IPv4ProtocolNumber, addr, linkAddr, stack.ReachabilityConfirmationFlags{ // Solicited and unsolicited (also referred to as gratuitous) ARP Replies // are handled equivalently to a solicited Neighbor Advertisement. Solicited: true, @@ -199,7 +205,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { Override: false, // ARP does not distinguish between router and non-router hosts. IsRouter: false, - }) + }); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ARP but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor confirmation message: %s", err)) + } } } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index d753a97af..24357e15d 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -491,9 +491,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { t.Fatal(err) } - neighbors, err := c.s.Neighbors(nicID) + neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber) if err != nil { - t.Fatalf("c.s.Neighbors(%d): %s", nicID, err) + t.Fatalf("c.s.Neighbors(%d, %d): %s", nicID, ipv4.ProtocolNumber, err) } neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 291330e8e..8d155344b 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -311,10 +311,12 @@ func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.N return &tcpip.ErrNotSupported{} } -func (*testInterface) HandleNeighborProbe(tcpip.Address, tcpip.LinkAddress, stack.LinkAddressResolver) { +func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { + return nil } -func (*testInterface) HandleNeighborConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { +func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { + return nil } func TestSourceAddressValidation(t *testing.T) { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdc88fe5d..12e5ead5e 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -290,7 +290,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { received.invalid.Increment() return } else { - e.nic.HandleNeighborProbe(srcAddr, sourceLinkAddr, e) + switch err := e.nic.HandleNeighborProbe(ProtocolNumber, srcAddr, sourceLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) + } } // As per RFC 4861 section 7.1.1: @@ -456,11 +462,17 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // If the NA message has the target link layer option, update the link // address cache with the link address for the target of the message. - e.nic.HandleNeighborConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ + switch err := e.nic.HandleNeighborConfirmation(ProtocolNumber, targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ Solicited: na.SolicitedFlag(), Override: na.OverrideFlag(), IsRouter: na.RouterFlag(), - }) + }); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor confirmation message: %s", err)) + } case header.ICMPv6EchoRequest: received.echoRequest.Increment() @@ -566,9 +578,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { return } - // A RS with a specified source IP address modifies the NUD state - // machine in the same way a reachability probe would. - e.nic.HandleNeighborProbe(srcAddr, sourceLinkAddr, e) + // A RS with a specified source IP address modifies the neighbor table + // in the same way a regular probe would. + switch err := e.nic.HandleNeighborProbe(ProtocolNumber, srcAddr, sourceLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) + } } case header.ICMPv6RouterAdvert: @@ -617,7 +635,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // If the RA has the source link layer option, update the link address // cache with the link address for the advertised router. if len(sourceLinkAddr) != 0 { - e.nic.HandleNeighborProbe(routerAddr, sourceLinkAddr, e) + switch err := e.nic.HandleNeighborProbe(ProtocolNumber, routerAddr, sourceLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) + } } e.mu.Lock() diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 755293377..4374d0198 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -139,12 +139,14 @@ func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gs return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) } -func (t *testInterface) HandleNeighborProbe(tcpip.Address, tcpip.LinkAddress, stack.LinkAddressResolver) { +func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { t.probeCount++ + return nil } -func (t *testInterface) HandleNeighborConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { +func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { t.confirmationCount++ + return nil } func TestICMPCounts(t *testing.T) { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 4cc81e6cc..e0245487b 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -338,18 +338,18 @@ func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *test Data: hdr.View().ToVectorisedView(), })) - neighbors, err := s.Neighbors(nicID) + neighbors, err := s.Neighbors(nicID, ProtocolNumber) if err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) } neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) for _, n := range neighbors { if existing, ok := neighborByAddr[n.Addr]; ok { if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) } neighborByAddr[n.Addr] = n } @@ -907,18 +907,18 @@ func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *tes Data: hdr.View().ToVectorisedView(), })) - neighbors, err := s.Neighbors(nicID) + neighbors, err := s.Neighbors(nicID, ProtocolNumber) if err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) } neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) for _, n := range neighbors { if existing, ok := neighborByAddr[n.Addr]; ok { if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) } neighborByAddr[n.Addr] = n } @@ -1277,8 +1277,8 @@ func TestNeighborAdvertisementValidation(t *testing.T) { // There is no need to create an entry if none exists, since the // recipient has apparently not initiated any communication with the // target. - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) + if neighbors, err := s.Neighbors(nicID, ProtocolNumber); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) } else if len(neighbors) != 0 { t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 704812641..c24f56ece 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -400,7 +400,10 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC if !ok { t.Fatal("NIC 2 does not exist") } - proto.neighborTable = nic.neighborTable + + if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok { + proto.neighborTable = l.neighborTable + } // Route all packets to NIC 2. { diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 4504db752..5b6b58b1d 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -32,6 +32,8 @@ const linkAddrCacheSize = 512 // max cache entries type linkAddrCache struct { nic *NIC + linkRes LinkAddressResolver + // ageLimit is how long a cache entry is valid for. ageLimit time.Duration @@ -196,10 +198,10 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { return entry } -// get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { +// get reports any known link address for addr. +func (c *linkAddrCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { c.mu.Lock() - entry := c.getOrCreateEntryLocked(k) + entry := c.getOrCreateEntryLocked(addr) entry.mu.Lock() defer entry.mu.Unlock() c.mu.Unlock() @@ -222,7 +224,7 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA } if entry.mu.done == nil { entry.mu.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + go c.startAddressResolution(addr, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{} default: @@ -230,11 +232,11 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA } } -func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.Address, localAddr tcpip.Address, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */) + c.linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */) select { case now := <-time.After(c.resolutionTimeout): @@ -278,15 +280,18 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt return true } -func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { - c := &linkAddrCache{ +func (c *linkAddrCache) init(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int, linkRes LinkAddressResolver) { + *c = linkAddrCache{ nic: nic, + linkRes: linkRes, ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, } + + c.mu.Lock() c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) - return c + c.mu.Unlock() } var _ neighborTable = (*linkAddrCache)(nil) @@ -307,7 +312,7 @@ func (*linkAddrCache) removeAll() tcpip.Error { return &tcpip.ErrNotSupported{} } -func (c *linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress, _ LinkAddressResolver) { +func (c *linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress) { if len(linkAddr) != 0 { // NUD allows probes without a link address but linkAddrCache // is a simple neighbor table which does not implement NUD. diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 4df6f9265..9e7f331c9 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -77,10 +77,10 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe return 1 } -func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, tcpip.Error) { +func getBlocking(c *linkAddrCache, addr tcpip.Address) (tcpip.LinkAddress, tcpip.Error) { var attemptedResolution bool for { - got, ch, err := c.get(addr, linkRes, "", nil) + got, ch, err := c.get(addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); ok { if attemptedResolution { return got, &tcpip.ErrTimeout{} @@ -100,27 +100,28 @@ func newEmptyNIC() *NIC { } func TestCacheOverflow(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) + var c linkAddrCache + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) for i := len(testAddrs) - 1; i >= 0; i-- { e := testAddrs[i] c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("insert %d, c.get(%s, nil, '', nil): %s", i, e.addr, err) + t.Errorf("insert %d, c.get(%s, '', nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("insert %d, got c.get(%s, nil, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) + t.Errorf("insert %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { e := testAddrs[i] - got, _, err := c.get(e.addr, nil, "", nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("check %d, c.get(%s, nil, '', nil): %s", i, e.addr, err) + t.Errorf("check %d, c.get(%s, '', nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("check %d, got c.get(%s, nil, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) + t.Errorf("check %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // The earliest entries should no longer be in the cache. @@ -135,8 +136,9 @@ func TestCacheOverflow(t *testing.T) { } func TestCacheConcurrent(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, linkRes) var wg sync.WaitGroup for r := 0; r < 16; r++ { @@ -154,12 +156,12 @@ func TestCacheConcurrent(t *testing.T) { // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, _, '', nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, _, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } e = testAddrs[0] @@ -171,38 +173,40 @@ func TestCacheConcurrent(t *testing.T) { } func TestCacheAgeLimit(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3, linkRes) e := testAddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - _, _, err := c.get(e.addr, linkRes, "", nil) + _, _, err := c.get(e.addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.get(%s, _, '', nil) = %s, want = ErrWouldBlock", e.addr, err) + t.Errorf("got c.get(%s, '', nil) = %s, want = ErrWouldBlock", e.addr, err) } } func TestCacheReplace(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) + var c linkAddrCache + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) e := testAddrs[0] l2 := e.linkAddr + "2" c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, nil, '', nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, nil, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } c.add(e.addr, l2) - got, _, err = c.get(e.addr, nil, "", nil) + got, _, err = c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, nil, '', nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != l2 { - t.Errorf("got c.get(%s, nil, '', nil) = %s, want = %s", e.addr, got, l2) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, l2) } } @@ -213,34 +217,36 @@ func TestCacheResolution(t *testing.T) { // // Using a large resolution timeout decreases the probability of experiencing // this race condition and does not affect how long this test takes to run. - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1, linkRes) for i, ta := range testAddrs { - got, err := getBlocking(c, ta.addr, linkRes) + got, err := getBlocking(&c, ta.addr) if err != nil { - t.Errorf("check %d, getBlocking(_, %s, _): %s", i, ta.addr, err) + t.Errorf("check %d, getBlocking(_, %s): %s", i, ta.addr, err) } if got != ta.linkAddr { - t.Errorf("check %d, got getBlocking(_, %s, _) = %s, want = %s", i, ta.addr, got, ta.linkAddr) + t.Errorf("check %d, got getBlocking(_, %s) = %s, want = %s", i, ta.addr, got, ta.linkAddr) } } // Check that after resolved, address stays in the cache and never returns WouldBlock. for i := 0; i < 10; i++ { e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, _, '', nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, _, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } } } func TestCacheResolutionFailed(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5, linkRes) var requestCount uint32 linkRes.onLinkAddressRequest = func() { @@ -249,20 +255,20 @@ func TestCacheResolutionFailed(t *testing.T) { // First, sanity check that resolution is working... e := testAddrs[0] - got, err := getBlocking(c, e.addr, linkRes) + got, err := getBlocking(&c, e.addr) if err != nil { - t.Errorf("getBlocking(_, %s, _): %s", e.addr, err) + t.Errorf("getBlocking(_, %s): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got getBlocking(_, %s, _) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got getBlocking(_, %s) = %s, want = %s", e.addr, got, e.linkAddr) } before := atomic.LoadUint32(&requestCount) e.addr += "2" - a, err := getBlocking(c, e.addr, linkRes) + a, err := getBlocking(&c, e.addr) if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) + t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { @@ -273,12 +279,13 @@ func TestCacheResolutionFailed(t *testing.T) { func TestCacheResolutionTimeout(t *testing.T) { resolverDelay := 500 * time.Millisecond expiration := resolverDelay / 10 - c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3) - linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c, delay: resolverDelay} + c.init(newEmptyNIC(), expiration, 1*time.Millisecond, 3, linkRes) e := testAddrs[0] - a, err := getBlocking(c, e.addr, linkRes) + a, err := getBlocking(&c, e.addr) if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) + t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index c13be137e..0238605af 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -2796,8 +2796,8 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN NIC: nicID, }}) - if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) + if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err) } return ndpDisp, e, s } diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 64b8046f5..7e3132058 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -43,8 +43,9 @@ type NeighborStats struct { // Their state is always Static. The amount of static entries stored in the // cache is unbounded. type neighborCache struct { - nic *NIC - state *NUDState + nic *NIC + state *NUDState + linkRes LinkAddressResolver // mu protects the fields below. mu sync.RWMutex @@ -69,7 +70,7 @@ type neighborCache struct { // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. -func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { +func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntry { n.mu.Lock() defer n.mu.Unlock() @@ -85,7 +86,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // The entry that needs to be created must be dynamic since all static // entries are directly added to the cache via addStaticEntry. - entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes) + entry := newNeighborEntry(n, remoteAddr, n.state) if n.dynamic.count == neighborCacheSize { e := n.dynamic.lru.Back() e.mu.Lock() @@ -122,8 +123,8 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // packet prompting NUD/link address resolution. // // TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) { - entry := n.getOrCreateEntry(remoteAddr, linkRes) +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) { + entry := n.getOrCreateEntry(remoteAddr) entry.mu.Lock() defer entry.mu.Unlock() @@ -202,7 +203,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd entry.mu.Unlock() } - n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) + n.cache[addr] = newStaticNeighborEntry(n, addr, linkAddr, n.state) } // removeEntry removes a dynamic or static entry by address from the neighbor @@ -265,8 +266,8 @@ func (n *neighborCache) neighbors() ([]NeighborEntry, tcpip.Error) { return n.entries(), nil } -func (n *neighborCache) get(addr tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - entry, ch, err := n.entry(addr, localAddr, linkRes, onResolve) +func (n *neighborCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + entry, ch, err := n.entry(addr, localAddr, onResolve) return entry.LinkAddr, ch, err } @@ -286,8 +287,8 @@ func (n *neighborCache) removeAll() tcpip.Error { // handleProbe handles a neighbor probe as defined by RFC 4861 section 7.2.3. // // Validation of the probe is expected to be handled by the caller. -func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { - entry := n.getOrCreateEntry(remoteAddr, linkRes) +func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + entry := n.getOrCreateEntry(remoteAddr) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 122888fcf..b489b5e08 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -76,10 +76,15 @@ func entryDiffOptsWithSort() []cmp.Option { })) } -func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { +func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *testNeighborResolver { config.resetInvalidFields() rng := rand.New(rand.NewSource(time.Now().UnixNano())) - neigh := &neighborCache{ + linkRes := &testNeighborResolver{ + clock: clock, + entries: newTestEntryStore(), + delay: typicalLatency, + } + linkRes.neigh = &neighborCache{ nic: &NIC{ stack: &Stack{ clock: clock, @@ -88,10 +93,11 @@ func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock id: 1, stats: makeNICStats(), }, - state: NewNUDState(config, rng), - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + state: NewNUDState(config, rng), + linkRes: linkRes, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } - return neigh + return linkRes } // testEntryStore contains a set of IP to NeighborEntry mappings. @@ -241,10 +247,10 @@ func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) + linkRes := newTestNeighborResolver(&nudDisp, c, clock) - if got, want := neigh.config(), c; got != want { - t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + if got, want := linkRes.neigh.config(), c; got != want { + t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) } // No events should have been dispatched. @@ -259,14 +265,14 @@ func TestNeighborCacheSetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) + linkRes := newTestNeighborResolver(&nudDisp, c, clock) c.MinRandomFactor = 1 c.MaxRandomFactor = 1 - neigh.setConfig(c) + linkRes.neigh.setConfig(c) - if got, want := neigh.config(), c; got != want { - t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + if got, want := linkRes.neigh.config(), c; got != want { + t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) } // No events should have been dispatched. @@ -281,22 +287,15 @@ func TestNeighborCacheEntry(t *testing.T) { c := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, c, clock) - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -328,8 +327,8 @@ func TestNeighborCacheEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + if _, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err != nil { + t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -345,23 +344,16 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -393,7 +385,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } - neigh.removeEntry(entry.Addr) + linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ @@ -416,17 +408,15 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } { - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } } } type testContext struct { clock *faketime.ManualClock - neigh *neighborCache - store *testEntryStore linkRes *testNeighborResolver nudDisp *testNUDDispatcher } @@ -434,19 +424,10 @@ type testContext struct { func newTestContext(c NUDConfigurations) testContext { nudDisp := &testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nudDisp, c, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(nudDisp, c, clock) return testContext{ clock: clock, - neigh: neigh, - store: store, linkRes: linkRes, nudDisp: nudDisp, } @@ -460,17 +441,17 @@ type overflowOptions struct { func (c *testContext) overflowCache(opts overflowOptions) error { // Fill the neighbor cache to capacity to verify the LRU eviction strategy is // working properly after the entry removal. - for i := opts.startAtEntryIndex; i < c.store.size(); i++ { + for i := opts.startAtEntryIndex; i < c.linkRes.entries.size(); i++ { // Add a new entry - entry, ok := c.store.entry(i) + entry, ok := c.linkRes.entries.entry(i) if !ok { - return fmt.Errorf("c.store.entry(%d) not found", i) + return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + return fmt.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } - c.clock.Advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) var wantEvents []testEntryEventInfo @@ -478,9 +459,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { // LRU eviction strategy. Note that the number of static entries should not // affect the total number of dynamic entries that can be added. if i >= neighborCacheSize+opts.startAtEntryIndex { - removedEntry, ok := c.store.entry(i - neighborCacheSize) + removedEntry, ok := c.linkRes.entries.entry(i - neighborCacheSize) if !ok { - return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize) + return fmt.Errorf("linkRes.entries.entry(%d) not found", i-neighborCacheSize) } wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestRemoved, @@ -523,10 +504,10 @@ func (c *testContext) overflowCache(opts overflowOptions) error { // by entries() is nondeterministic, so entries have to be sorted before // comparison. wantUnsortedEntries := opts.wantStaticEntries - for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { - entry, ok := c.store.entry(i) + for i := c.linkRes.entries.size() - neighborCacheSize; i < c.linkRes.entries.size(); i++ { + entry, ok := c.linkRes.entries.entry(i) if !ok { - return fmt.Errorf("c.store.entry(%d) not found", i) + return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -536,7 +517,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, c.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, c.linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } @@ -580,15 +561,15 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } - c.clock.Advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -617,7 +598,7 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { } // Remove the entry - c.neigh.removeEntry(entry.Addr) + c.linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ @@ -656,12 +637,12 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -682,7 +663,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { } // Remove the static entry that was just added - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) // No more events should have been dispatched. c.nudDisp.mu.Lock() @@ -700,12 +681,12 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -727,7 +708,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) // Add a duplicate entry with a different link address staticLinkAddr += "duplicate" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) { wantEvents := []testEntryEventInfo{ { @@ -762,12 +743,12 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -788,7 +769,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { } // Remove the static entry that was just added - c.neigh.removeEntry(entry.Addr) + c.linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ { @@ -832,13 +813,13 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -870,7 +851,7 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { // Override the entry with a static one using the same address staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) { wantEvents := []testEntryEventInfo{ { @@ -925,14 +906,14 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) - e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + c.linkRes.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) + e, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -940,7 +921,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { State: Static, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -982,23 +963,16 @@ func TestNeighborCacheClear(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) // Add a dynamic entry. - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -1030,7 +1004,7 @@ func TestNeighborCacheClear(t *testing.T) { } // Add a static entry. - neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) + linkRes.neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) { wantEvents := []testEntryEventInfo{ @@ -1054,7 +1028,7 @@ func TestNeighborCacheClear(t *testing.T) { } // Clear should remove both dynamic and static entries. - neigh.clear() + linkRes.neigh.clear() // Remove events dispatched from clear() have no deterministic order so they // need to be sorted beforehand. @@ -1098,13 +1072,13 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -1135,7 +1109,7 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { } // Clear the cache. - c.neigh.clear() + c.linkRes.neigh.clear() { wantEvents := []testEntryEventInfo{ { @@ -1174,18 +1148,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - frequentlyUsedEntry, ok := store.entry(0) + frequentlyUsedEntry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // The following logic is very similar to overflowCache, but @@ -1193,23 +1160,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // Fill the neighbor cache to capacity for i := 0; i < neighborCacheSize; i++ { - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { @@ -1240,38 +1207,38 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Keep adding more entries - for i := neighborCacheSize; i < store.size(); i++ { + for i := neighborCacheSize; i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { - if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err) + if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", frequentlyUsedEntry.Addr, err) } } - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy - removedEntry, ok := store.entry(i - neighborCacheSize + 1) + removedEntry, ok := linkRes.entries.entry(i - neighborCacheSize + 1) if !ok { - t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1) + t.Fatalf("linkRes.entries.entry(%d) not found", i-neighborCacheSize+1) } wantEvents := []testEntryEventInfo{ { @@ -1321,10 +1288,10 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { }, } - for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ { - entry, ok := store.entry(i) + for i := linkRes.entries.size() - neighborCacheSize + 1; i < linkRes.entries.size(); i++ { + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -1334,7 +1301,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } @@ -1353,26 +1320,19 @@ func TestNeighborCacheConcurrent(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - storeEntries := store.entries() + storeEntries := linkRes.entries.entries() for _, entry := range storeEntries { var wg sync.WaitGroup for r := 0; r < concurrentProcesses; r++ { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - switch e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err.(type) { + switch e, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err.(type) { case nil, *tcpip.ErrWouldBlock: default: - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) } }(entry) } @@ -1390,10 +1350,10 @@ func TestNeighborCacheConcurrent(t *testing.T) { // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. var wantUnsortedEntries []NeighborEntry - for i := store.size() - neighborCacheSize; i < store.size(); i++ { - entry, ok := store.entry(i) + for i := linkRes.entries.size() - neighborCacheSize; i < linkRes.entries.size(); i++ { + entry, ok := linkRes.entries.entry(i) if !ok { - t.Errorf("store.entry(%d) not found", i) + t.Errorf("linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -1403,7 +1363,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } } @@ -1413,41 +1373,34 @@ func TestNeighborCacheReplace(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) // Add an entry - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } // Verify the entry exists { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) } if t.Failed() { t.FailNow() @@ -1458,21 +1411,21 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } // Notify of a link address change var updatedLinkAddr tcpip.LinkAddress { - entry, ok := store.entry(1) + entry, ok := linkRes.entries.entry(1) if !ok { - t.Fatal("store.entry(1) not found") + t.Fatal("linkRes.entries.entry(1) not found") } updatedLinkAddr = entry.LinkAddr } - store.set(0, updatedLinkAddr) - neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ + linkRes.entries.set(0, updatedLinkAddr) + linkRes.neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, @@ -1482,9 +1435,9 @@ func TestNeighborCacheReplace(t *testing.T) { // // Verify the entry's new link address and the new state. { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1492,17 +1445,17 @@ func TestNeighborCacheReplace(t *testing.T) { State: Delay, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } // Verify that the neighbor is now reachable. { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1510,7 +1463,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } } @@ -1520,46 +1473,39 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() + linkRes := newTestNeighborResolver(&nudDisp, config, clock) var requestCount uint32 - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - onLinkAddressRequest: func() { - atomic.AddUint32(&requestCount, 1) - }, + linkRes.onLinkAddressRequest = func() { + atomic.AddUint32(&requestCount, 1) } - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // First, sanity check that resolution is working { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } - got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + got, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1567,7 +1513,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { State: Reachable, } if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } // Verify address resolution fails for an unknown address. @@ -1575,24 +1521,24 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { entry.Addr += "2" { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } - maxAttempts := neigh.config().MaxUnicastProbes + maxAttempts := linkRes.neigh.config().MaxUnicastProbes if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want { t.Errorf("got link address request count = %d, want = %d", got, want) } @@ -1606,27 +1552,22 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { config.RetransmitTimer = time.Millisecond // small enough to cause timeout clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: time.Minute, // large enough to cause timeout - } + linkRes := newTestNeighborResolver(nil, config, clock) + // large enough to cause timeout + linkRes.delay = time.Minute - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1634,7 +1575,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } @@ -1643,31 +1584,24 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { func TestNeighborCacheRetryResolution(t *testing.T) { config := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - // Simulate a faulty link. - dropReplies: true, - } + linkRes := newTestNeighborResolver(nil, config, clock) + // Simulate a faulty link. + linkRes.dropReplies = true - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // Perform address resolution with a faulty link, which will fail. { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1675,7 +1609,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } @@ -1687,20 +1621,20 @@ func TestNeighborCacheRetryResolution(t *testing.T) { State: Failed, }, } - if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) } // Retry address resolution with a working link. linkRes.dropReplies = false { - incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + incompleteEntry, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } if incompleteEntry.State != Incomplete { t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) @@ -1712,9 +1646,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) { if !ok { t.Fatal("expected successful address resolution") } - reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + reachableEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) + t.Fatalf("linkRes.neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) } if reachableEntry.Addr != entry.Addr { t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr) @@ -1726,7 +1660,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String()) } default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } } @@ -1735,42 +1669,36 @@ func BenchmarkCacheClear(b *testing.B) { b.StopTimer() config := DefaultNUDConfigurations() clock := &tcpip.StdClock{} - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: 0, - } + linkRes := newTestNeighborResolver(nil, config, clock) + linkRes.delay = 0 // Clear for every possible size of the cache for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. for i := 0; i < cacheSize; i++ { - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - b.Fatalf("store.entry(%d) not found", i) + b.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + b.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } select { case <-ch: default: - b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } b.StartTimer() - neigh.clear() + linkRes.neigh.clear() b.StopTimer() } } diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index a037ca6f9..b05f96d4f 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -77,11 +77,7 @@ const ( type neighborEntry struct { neighborEntryEntry - nic *NIC - - // linkRes provides the functionality to send reachability probes, used in - // Neighbor Unreachability Detection. - linkRes LinkAddressResolver + cache *neighborCache // nudState points to the Neighbor Unreachability Detection configuration. nudState *NUDState @@ -106,10 +102,9 @@ type neighborEntry struct { // state, Unknown. Transition out of Unknown by calling either // `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created // neighborEntry. -func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { +func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState *NUDState) *neighborEntry { return &neighborEntry{ - nic: nic, - linkRes: linkRes, + cache: cache, nudState: nudState, neigh: NeighborEntry{ Addr: remoteAddr, @@ -121,18 +116,18 @@ func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, li // newStaticNeighborEntry creates a neighbor cache entry starting at the // Static state. The entry can only transition out of Static by directly // calling `setStateLocked`. -func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { +func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { entry := NeighborEntry{ Addr: addr, LinkAddr: linkAddr, State: Static, - UpdatedAtNanos: nic.stack.clock.NowNanoseconds(), + UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(), } - if nic.stack.nudDisp != nil { - nic.stack.nudDisp.OnNeighborAdded(nic.id, entry) + if nudDisp := cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(cache.nic.id, entry) } return &neighborEntry{ - nic: nic, + cache: cache, nudState: state, neigh: entry, } @@ -158,7 +153,7 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // is resolved (which ends up obtaining the entry's lock) while holding the // link resolution queue's lock. Dequeuing packets in a new goroutine avoids // a lock ordering violation. - go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) + go e.cache.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) } } @@ -167,8 +162,8 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchAddEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(e.cache.nic.id, e.neigh) } } @@ -177,8 +172,8 @@ func (e *neighborEntry) dispatchAddEventLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchChangeEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborChanged(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborChanged(e.cache.nic.id, e.neigh) } } @@ -187,8 +182,8 @@ func (e *neighborEntry) dispatchChangeEventLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchRemoveEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborRemoved(e.cache.nic.id, e.neigh) } } @@ -206,7 +201,7 @@ func (e *neighborEntry) cancelJobLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) removeLocked() { - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() e.dispatchRemoveEventLocked() e.cancelJobLocked() e.notifyCompletionLocked(false /* succeeded */) @@ -222,7 +217,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.neigh.State e.neigh.State = next - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() config := e.nudState.Config() switch next { @@ -230,14 +225,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev)) case Reachable: - e.job = e.nic.stack.newJob(&e.mu, func() { + e.job = e.cache.nic.stack.newJob(&e.mu, func() { e.setStateLocked(Stale) e.dispatchChangeEventLocked() }) e.job.Schedule(e.nudState.ReachableTime()) case Delay: - e.job = e.nic.stack.newJob(&e.mu, func() { + e.job = e.cache.nic.stack.newJob(&e.mu, func() { e.setStateLocked(Probe) e.dispatchChangeEventLocked() }) @@ -254,14 +249,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil { + if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return } retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(config.RetransmitTimer) } @@ -269,7 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // for finishing the state transition. This is necessary to avoid // deadlock where sending and processing probes are done synchronously, // such as loopback and integration tests. - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(immediateDuration) case Failed: @@ -292,12 +287,12 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { case Failed: - e.nic.stats.Neighbor.FailedEntryLookups.Increment() + e.cache.nic.stats.Neighbor.FailedEntryLookups.Increment() fallthrough case Unknown: e.neigh.State = Incomplete - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() e.dispatchAddEventLocked() @@ -340,7 +335,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // address SHOULD be placed in the IP Source Address of the outgoing // solicitation. // - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil { + if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil { // There is no need to log the error here; the NUD implementation may // assume a working link. A valid link should be the responsibility of // the NIC/stack.LinkEndpoint. @@ -350,7 +345,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { } retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) e.job.Schedule(config.RetransmitTimer) } @@ -358,7 +353,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // for finishing the state transition. This is necessary to avoid // deadlock where sending and processing probes are done synchronously, // such as loopback and integration tests. - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) e.job.Schedule(immediateDuration) case Stale: @@ -504,7 +499,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // // TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6 // here. - ep, ok := e.nic.networkEndpoints[header.IPv6ProtocolNumber] + ep, ok := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber] if !ok { panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint")) } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 5e5e0e6ca..57cfbdb8b 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -230,23 +230,30 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e }, stats: makeNICStats(), } + netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ - header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil), + header.IPv6ProtocolNumber: netEP, } rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) - linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes) - + var linkRes entryTestLinkResolver // Stub out the neighbor cache to verify deletion from the cache. neigh := &neighborCache{ - nic: &nic, - state: nudState, - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + nic: &nic, + state: nudState, + linkRes: &linkRes, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + l := linkResolver{ + resolver: &linkRes, + neighborTable: neigh, } + entry := newNeighborEntry(neigh, entryTestAddr1 /* remoteAddr */, nudState) neigh.cache[entryTestAddr1] = entry - nic.neighborTable = neigh + nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]linkResolver{ + header.IPv6ProtocolNumber: l, + } return entry, &disp, &linkRes, clock } @@ -836,7 +843,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) - ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) + ipv6EP := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index c813b0da5..693ea064a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -27,11 +27,11 @@ import ( type neighborTable interface { neighbors() ([]NeighborEntry, tcpip.Error) addStaticEntry(tcpip.Address, tcpip.LinkAddress) - get(addr tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) + get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) remove(tcpip.Address) tcpip.Error removeAll() tcpip.Error - handleProbe(tcpip.Address, tcpip.LinkAddress, LinkAddressResolver) + handleProbe(tcpip.Address, tcpip.LinkAddress) handleConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) handleUpperLevelConfirmation(tcpip.Address) @@ -41,6 +41,20 @@ type neighborTable interface { var _ NetworkInterface = (*NIC)(nil) +type linkResolver struct { + resolver LinkAddressResolver + + neighborTable neighborTable +} + +func (l *linkResolver) getNeighborLinkAddress(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + return l.neighborTable.get(addr, localAddr, onResolve) +} + +func (l *linkResolver) confirmReachable(addr tcpip.Address) { + l.neighborTable.handleUpperLevelConfirmation(addr) +} + // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { @@ -56,7 +70,7 @@ type NIC struct { // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint - linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + linkAddrResolvers map[tcpip.NetworkProtocolNumber]linkResolver // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. // @@ -67,8 +81,6 @@ type NIC struct { // complete. linkResQueue packetsPendingLinkResolution - neighborTable neighborTable - mu struct { sync.RWMutex spoofing bool @@ -153,25 +165,13 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC context: ctx, stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]linkResolver), } nic.linkResQueue.init(nic) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0 - if resolutionRequired { - if stack.useNeighborCache { - nic.neighborTable = &neighborCache{ - nic: nic, - state: NewNUDState(stack.nudConfigs, stack.randomGenerator), - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), - } - } else { - nic.neighborTable = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts) - } - } - // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { nic.mu.packetEPs[netProto] = new(packetEndpointList) @@ -185,7 +185,24 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC if resolutionRequired { if r, ok := netEP.(LinkAddressResolver); ok { - nic.linkAddrResolvers[r.LinkAddressProtocol()] = r + l := linkResolver{ + resolver: r, + } + + if stack.useNeighborCache { + l.neighborTable = &neighborCache{ + nic: nic, + state: NewNUDState(stack.nudConfigs, stack.randomGenerator), + linkRes: r, + + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + } else { + cache := new(linkAddrCache) + cache.init(nic, ageLimit, resolutionTimeout, resolutionAttempts, r) + l.neighborTable = cache + } + nic.linkAddrResolvers[r.LinkAddressProtocol()] = l } } } @@ -240,18 +257,19 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() - } - // Clear the neighbour table (including static entries) as we cannot guarantee - // that the current neighbour table will be valid when the NIC is enabled - // again. - // - // This matches linux's behaviour at the time of writing: - // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 - switch err := n.clearNeighbors(); err.(type) { - case nil, *tcpip.ErrNotSupported: - default: - panic(fmt.Sprintf("n.clearNeighbors(): %s", err)) + // Clear the neighbour table (including static entries) as we cannot + // guarantee that the current neighbour table will be valid when the NIC is + // enabled again. + // + // This matches linux's behaviour at the time of writing: + // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 + netProto := ep.NetworkProtocolNumber() + switch err := n.clearNeighbors(netProto); err.(type) { + case nil, *tcpip.ErrNotSupported: + default: + panic(fmt.Sprintf("n.clearNeighbors(%d): %s", netProto, err)) + } } if !n.setEnabled(false) { @@ -604,63 +622,49 @@ func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error { return &tcpip.ErrBadLocalAddress{} } -func (n *NIC) confirmReachable(addr tcpip.Address) { - if n.neighborTable != nil { - n.neighborTable.handleUpperLevelConfirmation(addr) - } -} - func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { linkRes, ok := n.linkAddrResolvers[protocol] if !ok { return &tcpip.ErrNotSupported{} } - if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok { + if linkAddr, ok := linkRes.resolver.ResolveStaticAddress(addr); ok { onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) return nil } - _, _, err := n.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) + _, _, err := linkRes.getNeighborLinkAddress(addr, localAddr, onResolve) return err } -func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - if n.neighborTable != nil { - return n.neighborTable.get(addr, linkRes, localAddr, onResolve) - } - - return "", nil, &tcpip.ErrNotSupported{} -} - -func (n *NIC) neighbors() ([]NeighborEntry, tcpip.Error) { - if n.neighborTable != nil { - return n.neighborTable.neighbors() +func (n *NIC) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.neighbors() } return nil, &tcpip.ErrNotSupported{} } -func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) tcpip.Error { - if n.neighborTable != nil { - n.neighborTable.addStaticEntry(addr, linkAddress) +func (n *NIC) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtocolNumber, linkAddress tcpip.LinkAddress) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + linkRes.neighborTable.addStaticEntry(addr, linkAddress) return nil } return &tcpip.ErrNotSupported{} } -func (n *NIC) removeNeighbor(addr tcpip.Address) tcpip.Error { - if n.neighborTable != nil { - return n.neighborTable.remove(addr) +func (n *NIC) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.remove(addr) } return &tcpip.ErrNotSupported{} } -func (n *NIC) clearNeighbors() tcpip.Error { - if n.neighborTable != nil { - return n.neighborTable.removeAll() +func (n *NIC) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.removeAll() } return &tcpip.ErrNotSupported{} @@ -947,9 +951,9 @@ func (n *NIC) Name() string { } // nudConfigs gets the NUD configurations for n. -func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) { - if n.neighborTable != nil { - return n.neighborTable.nudConfig() +func (n *NIC) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.nudConfig() } return NUDConfigurations{}, &tcpip.ErrNotSupported{} @@ -959,10 +963,10 @@ func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) { // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (n *NIC) setNUDConfigs(c NUDConfigurations) tcpip.Error { - if n.neighborTable != nil { +func (n *NIC) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { c.resetInvalidFields() - return n.neighborTable.setNUDConfig(c) + return linkRes.neighborTable.setNUDConfig(c) } return &tcpip.ErrNotSupported{} @@ -1003,15 +1007,21 @@ func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { } // HandleNeighborProbe implements NetworkInterface. -func (n *NIC) HandleNeighborProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { - if n.neighborTable != nil { - n.neighborTable.handleProbe(addr, linkAddr, linkRes) +func (n *NIC) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { + if l, ok := n.linkAddrResolvers[protocol]; ok { + l.neighborTable.handleProbe(addr, linkAddr) + return nil } + + return &tcpip.ErrNotSupported{} } // HandleNeighborConfirmation implements NetworkInterface. -func (n *NIC) HandleNeighborConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { - if n.neighborTable != nil { - n.neighborTable.handleConfirmation(addr, linkAddr, flags) +func (n *NIC) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) tcpip.Error { + if l, ok := n.linkAddrResolvers[protocol]; ok { + l.neighborTable.handleConfirmation(addr, linkAddr, flags) + return nil } + + return &tcpip.ErrNotSupported{} } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 504acc246..e9acef6a2 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -19,7 +19,9 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -52,66 +54,146 @@ func (f *fakeRand) Float32() float32 { return f.num } -// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if -// we attempt to update NUD configurations using an invalid NICID. -func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The networking - // stack will only allocate neighbor caches if a protocol providing link - // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - UseNeighborCache: true, - }) +func TestNUDFunctions(t *testing.T) { + const nicID = 1 - // No NIC with ID 1 yet. - config := stack.NUDConfigurations{} - err := s.SetNUDConfigurations(1, config) - if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { - t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, &tcpip.ErrUnknownNICID{}) + tests := []struct { + name string + nicID tcpip.NICID + netProtoFactory []stack.NetworkProtocolFactory + extraLinkCapabilities stack.LinkEndpointCapabilities + expectedErr tcpip.Error + }{ + { + name: "Invalid NICID", + nicID: nicID + 1, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + expectedErr: &tcpip.ErrUnknownNICID{}, + }, + { + name: "No network protocol", + nicID: nicID, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With IPv6", + nicID: nicID, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With resolution capability", + nicID: nicID, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With IPv6 and resolution capability", + nicID: nicID, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + }, } -} -// TestNUDConfigurationFailsForNotSupported tests to make sure we get a -// NotSupported error if we attempt to retrieve or set NUD configurations when -// the stack doesn't support NUD. -// -// The stack will report to not support NUD if a neighbor cache for a given NIC -// is not allocated. The networking stack will only allocate neighbor caches if -// the NIC requires link resolution. -func TestNUDConfigurationFailsForNotSupported(t *testing.T) { - const nicID = 1 + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, + NetworkProtocols: test.netProtoFactory, + Clock: clock, + }) - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired + e := channel.New(0, 0, linkAddr1) + e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired + e.LinkEPCapabilities |= test.extraLinkCapabilities - s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - t.Run("Get", func(t *testing.T) { - _, err := s.NUDConfigurations(nicID) - if _, ok := err.(*tcpip.ErrNotSupported); !ok { - t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{}) - } - }) + configs := stack.DefaultNUDConfigurations() + configs.BaseReachableTime = time.Hour - t.Run("Set", func(t *testing.T) { - config := stack.NUDConfigurations{} - err := s.SetNUDConfigurations(nicID, config) - if _, ok := err.(*tcpip.ErrNotSupported); !ok { - t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{}) - } - }) + { + err := s.SetNUDConfigurations(test.nicID, ipv6.ProtocolNumber, configs) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.SetNUDConfigurations(%d, %d, _) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } + } + + { + gotConfigs, err := s.NUDConfigurations(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.NUDConfigurations(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if diff := cmp.Diff(configs, gotConfigs); diff != "" { + t.Errorf("got configs mismatch (-want +got):\n%s", diff) + } + } + } + + for _, addr := range []tcpip.Address{llAddr1, llAddr2} { + { + err := s.AddStaticNeighbor(test.nicID, ipv6.ProtocolNumber, addr, linkAddr1) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.AddStaticNeighbor(%d, %d, %s, %s) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, addr, linkAddr1, diff) + } + } + } + + { + wantErr := test.expectedErr + for i := 0; i < 2; i++ { + { + err := s.RemoveNeighbor(test.nicID, ipv6.ProtocolNumber, llAddr1) + if diff := cmp.Diff(wantErr, err); diff != "" { + t.Errorf("s.RemoveNeighbor(%d, %d, '') error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } + } + + if test.expectedErr != nil { + break + } + + // Removing a neighbor that does not exist should give us a bad address + // error. + wantErr = &tcpip.ErrBadAddress{} + } + } + + { + neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if diff := cmp.Diff( + []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + neighbors, + ); diff != "" { + t.Errorf("neighbors mismatch (-want +got):\n%s", diff) + } + } + } + + { + err := s.ClearNeighbors(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.ClearNeigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber); err != nil { + t.Errorf("s.Neighbors(%d, %d): %s", test.nicID, ipv6.ProtocolNumber, err) + } else if len(neighbors) != 0 { + t.Errorf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + } + } + } + }) + } } -// TestDefaultNUDConfigurationIsValid verifies that calling -// resetInvalidFields() on the result of DefaultNUDConfigurations() does not -// change anything. DefaultNUDConfigurations() should return a valid -// NUDConfigurations. func TestDefaultNUDConfigurations(t *testing.T) { const nicID = 1 @@ -129,12 +211,12 @@ func TestDefaultNUDConfigurations(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - c, err := s.NUDConfigurations(nicID) + c, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got, want := c, stack.DefaultNUDConfigurations(); got != want { - t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want) + t.Errorf("got stack.NUDConfigurations(%d, %d) = %+v, want = %+v", nicID, ipv6.ProtocolNumber, got, want) } } @@ -184,9 +266,9 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.BaseReachableTime; got != test.want { t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want) @@ -241,9 +323,9 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MinRandomFactor; got != test.want { t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want) @@ -321,9 +403,9 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxRandomFactor; got != test.want { t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want) @@ -383,9 +465,9 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.RetransmitTimer; got != test.want { t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want) @@ -435,9 +517,9 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.DelayFirstProbeTime; got != test.want { t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want) @@ -487,9 +569,9 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxMulticastProbes; got != test.want { t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want) @@ -539,9 +621,9 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxUnicastProbes; got != test.want { t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index c652c2bd7..e02f7190c 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -536,11 +536,11 @@ type NetworkInterface interface { // // HandleNeighborProbe assumes that the probe is valid for the network // interface the probe was received on. - HandleNeighborProbe(tcpip.Address, tcpip.LinkAddress, LinkAddressResolver) + HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error // HandleNeighborConfirmation processes an incoming neighbor confirmation // (e.g. ARP reply or NDP Neighbor Advertisement). - HandleNeighborConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) + HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) tcpip.Error } // LinkResolvableNetworkEndpoint handles link resolution events. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 1c8ef6ed4..bab55ce49 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -53,7 +53,7 @@ type Route struct { // linkRes is set if link address resolution is enabled for this protocol on // the route's NIC. - linkRes LinkAddressResolver + linkRes linkResolver } type routeInfo struct { @@ -184,11 +184,11 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA return r } - if r.linkRes == nil { + if r.linkRes.resolver == nil { return r } - if linkAddr, ok := r.linkRes.ResolveStaticAddress(r.RemoteAddress); ok { + if linkAddr, ok := r.linkRes.resolver.ResolveStaticAddress(r.RemoteAddress); ok { r.ResolveWith(linkAddr) return r } @@ -362,7 +362,7 @@ func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteIn } afterResolveFields := fields - linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) { + linkAddr, ch, err := r.linkRes.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, func(r LinkResolutionResult) { if afterResolve != nil { if r.Success { afterResolveFields.RemoteLinkAddress = r.LinkAddress @@ -400,7 +400,7 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isResolutionRequiredRLocked() bool { - return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local() + return len(r.mu.remoteLinkAddress) == 0 && r.linkRes.resolver != nil && r.isValidForOutgoingRLocked() && !r.local() } func (r *Route) isValidForOutgoing() bool { @@ -528,5 +528,7 @@ func (r *Route) IsOutboundBroadcast() bool { // "Reachable" is defined as having full-duplex communication between the // local and remote ends of the route. func (r *Route) ConfirmReachable() { - r.outgoingNIC.confirmReachable(r.nextHop()) + if r.linkRes.resolver != nil { + r.linkRes.confirmReachable(r.nextHop()) + } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 73db6e031..9390aaf57 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1560,7 +1560,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, } // Neighbors returns all IP to MAC address associations. -func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) { +func (s *Stack) Neighbors(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1569,11 +1569,11 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) { return nil, &tcpip.ErrUnknownNICID{} } - return nic.neighbors() + return nic.neighbors(protocol) } // AddStaticNeighbor statically associates an IP address to a MAC address. -func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { +func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1582,13 +1582,13 @@ func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAdd return &tcpip.ErrUnknownNICID{} } - return nic.addStaticNeighbor(addr, linkAddr) + return nic.addStaticNeighbor(addr, protocol, linkAddr) } // RemoveNeighbor removes an IP to MAC address association previously created // either automically or by AddStaticNeighbor. Returns ErrBadAddress if there // is no association with the provided address. -func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Error { +func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1597,11 +1597,11 @@ func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Erro return &tcpip.ErrUnknownNICID{} } - return nic.removeNeighbor(addr) + return nic.removeNeighbor(protocol, addr) } // ClearNeighbors removes all IP to MAC address associations. -func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error { +func (s *Stack) ClearNeighbors(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1610,7 +1610,7 @@ func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error { return &tcpip.ErrUnknownNICID{} } - return nic.clearNeighbors() + return nic.clearNeighbors(protocol) } // RegisterTransportEndpoint registers the given endpoint with the stack @@ -1980,7 +1980,7 @@ func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtoco } // NUDConfigurations gets the per-interface NUD configurations. -func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Error) { +func (s *Stack) NUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() @@ -1989,14 +1989,14 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Erro return NUDConfigurations{}, &tcpip.ErrUnknownNICID{} } - return nic.nudConfigs() + return nic.nudConfigs(proto) } // SetNUDConfigurations sets the per-interface NUD configurations. // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip.Error { +func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() @@ -2005,7 +2005,7 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip. return &tcpip.ErrUnknownNICID{} } - return nic.setNUDConfigs(c) + return nic.setNUDConfigs(proto, c) } // Seed returns a 32 bit value that can be used as a seed value for port diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index a166c0502..375cd3080 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -4313,9 +4314,11 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") ) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, UseNeighborCache: true, + Clock: clock, }) e := channel.New(0, 0, "") e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -4323,36 +4326,56 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddStaticNeighbor(nicID, ipv4Addr, linkAddr); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv4Addr, linkAddr, err) - } - if err := s.AddStaticNeighbor(nicID, ipv6Addr, linkAddr); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv6Addr, linkAddr, err) + addrs := []struct { + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + }{ + { + proto: ipv4.ProtocolNumber, + addr: ipv4Addr, + }, + { + proto: ipv6.ProtocolNumber, + addr: ipv6Addr, + }, } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 2 { - t.Fatalf("got len(neighbors) = %d, want = 2; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if err := s.AddStaticNeighbor(nicID, addr.proto, addr.addr, linkAddr); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, addr.proto, addr.addr, linkAddr, err) + } + + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if diff := cmp.Diff( + []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + neighbors, + ); diff != "" { + t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff) + } } // Disabling the NIC should clear the neighbor table. if err := s.DisableNIC(nicID); err != nil { t.Fatalf("s.DisableNIC(%d): %s", nicID, err) } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 0 { - t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if len(neighbors) != 0 { + t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) + } } // Enabling the NIC should have an empty neighbor table. if err := s.EnableNIC(nicID); err != nil { t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 0 { - t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if len(neighbors) != 0 { + t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) + } } } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 7069352f2..b3a5d49d7 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -1069,9 +1069,9 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // Wait for the remote's neighbor entry to be stale before creating a // TCP connection from host1 to some remote. - nudConfigs, err := host1Stack.NUDConfigurations(host1NICID) + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) if err != nil { - t.Fatalf("host1Stack.NUDConfigurations(%d): %s", host1NICID, err) + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) } // The maximum reachable time for a neighbor is some maximum random factor // applied to the base reachable time. |