diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/arp_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 95 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache_test.go | 434 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry.go | 61 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_test.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 150 | ||||
-rw-r--r-- | pkg/tcpip/stack/nud_test.go | 218 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 57 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 4 |
21 files changed, 669 insertions, 566 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 5fd4c5574..0d7fadc31 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -148,7 +148,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { remoteAddr := tcpip.Address(h.ProtocolAddressSender()) remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.nic.HandleNeighborProbe(remoteAddr, remoteLinkAddr, e) + switch err := e.nic.HandleNeighborProbe(header.IPv4ProtocolNumber, remoteAddr, remoteLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ARP but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) + } respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, @@ -190,7 +196,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // The solicited, override, and isRouter flags are not available for ARP; // they are only available for IPv6 Neighbor Advertisements. - e.nic.HandleNeighborConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{ + switch err := e.nic.HandleNeighborConfirmation(header.IPv4ProtocolNumber, addr, linkAddr, stack.ReachabilityConfirmationFlags{ // Solicited and unsolicited (also referred to as gratuitous) ARP Replies // are handled equivalently to a solicited Neighbor Advertisement. Solicited: true, @@ -199,7 +205,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { Override: false, // ARP does not distinguish between router and non-router hosts. IsRouter: false, - }) + }); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ARP but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor confirmation message: %s", err)) + } } } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index d753a97af..24357e15d 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -491,9 +491,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { t.Fatal(err) } - neighbors, err := c.s.Neighbors(nicID) + neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber) if err != nil { - t.Fatalf("c.s.Neighbors(%d): %s", nicID, err) + t.Fatalf("c.s.Neighbors(%d, %d): %s", nicID, ipv4.ProtocolNumber, err) } neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 291330e8e..8d155344b 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -311,10 +311,12 @@ func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.N return &tcpip.ErrNotSupported{} } -func (*testInterface) HandleNeighborProbe(tcpip.Address, tcpip.LinkAddress, stack.LinkAddressResolver) { +func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { + return nil } -func (*testInterface) HandleNeighborConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { +func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { + return nil } func TestSourceAddressValidation(t *testing.T) { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdc88fe5d..12e5ead5e 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -290,7 +290,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { received.invalid.Increment() return } else { - e.nic.HandleNeighborProbe(srcAddr, sourceLinkAddr, e) + switch err := e.nic.HandleNeighborProbe(ProtocolNumber, srcAddr, sourceLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) + } } // As per RFC 4861 section 7.1.1: @@ -456,11 +462,17 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // If the NA message has the target link layer option, update the link // address cache with the link address for the target of the message. - e.nic.HandleNeighborConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ + switch err := e.nic.HandleNeighborConfirmation(ProtocolNumber, targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ Solicited: na.SolicitedFlag(), Override: na.OverrideFlag(), IsRouter: na.RouterFlag(), - }) + }); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor confirmation message: %s", err)) + } case header.ICMPv6EchoRequest: received.echoRequest.Increment() @@ -566,9 +578,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { return } - // A RS with a specified source IP address modifies the NUD state - // machine in the same way a reachability probe would. - e.nic.HandleNeighborProbe(srcAddr, sourceLinkAddr, e) + // A RS with a specified source IP address modifies the neighbor table + // in the same way a regular probe would. + switch err := e.nic.HandleNeighborProbe(ProtocolNumber, srcAddr, sourceLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) + } } case header.ICMPv6RouterAdvert: @@ -617,7 +635,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // If the RA has the source link layer option, update the link address // cache with the link address for the advertised router. if len(sourceLinkAddr) != 0 { - e.nic.HandleNeighborProbe(routerAddr, sourceLinkAddr, e) + switch err := e.nic.HandleNeighborProbe(ProtocolNumber, routerAddr, sourceLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) + } } e.mu.Lock() diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 755293377..4374d0198 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -139,12 +139,14 @@ func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gs return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) } -func (t *testInterface) HandleNeighborProbe(tcpip.Address, tcpip.LinkAddress, stack.LinkAddressResolver) { +func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { t.probeCount++ + return nil } -func (t *testInterface) HandleNeighborConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { +func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { t.confirmationCount++ + return nil } func TestICMPCounts(t *testing.T) { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 4cc81e6cc..e0245487b 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -338,18 +338,18 @@ func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *test Data: hdr.View().ToVectorisedView(), })) - neighbors, err := s.Neighbors(nicID) + neighbors, err := s.Neighbors(nicID, ProtocolNumber) if err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) } neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) for _, n := range neighbors { if existing, ok := neighborByAddr[n.Addr]; ok { if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) } neighborByAddr[n.Addr] = n } @@ -907,18 +907,18 @@ func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *tes Data: hdr.View().ToVectorisedView(), })) - neighbors, err := s.Neighbors(nicID) + neighbors, err := s.Neighbors(nicID, ProtocolNumber) if err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) } neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) for _, n := range neighbors { if existing, ok := neighborByAddr[n.Addr]; ok { if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) } neighborByAddr[n.Addr] = n } @@ -1277,8 +1277,8 @@ func TestNeighborAdvertisementValidation(t *testing.T) { // There is no need to create an entry if none exists, since the // recipient has apparently not initiated any communication with the // target. - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) + if neighbors, err := s.Neighbors(nicID, ProtocolNumber); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) } else if len(neighbors) != 0 { t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 704812641..c24f56ece 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -400,7 +400,10 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC if !ok { t.Fatal("NIC 2 does not exist") } - proto.neighborTable = nic.neighborTable + + if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok { + proto.neighborTable = l.neighborTable + } // Route all packets to NIC 2. { diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 4504db752..5b6b58b1d 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -32,6 +32,8 @@ const linkAddrCacheSize = 512 // max cache entries type linkAddrCache struct { nic *NIC + linkRes LinkAddressResolver + // ageLimit is how long a cache entry is valid for. ageLimit time.Duration @@ -196,10 +198,10 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { return entry } -// get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { +// get reports any known link address for addr. +func (c *linkAddrCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { c.mu.Lock() - entry := c.getOrCreateEntryLocked(k) + entry := c.getOrCreateEntryLocked(addr) entry.mu.Lock() defer entry.mu.Unlock() c.mu.Unlock() @@ -222,7 +224,7 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA } if entry.mu.done == nil { entry.mu.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + go c.startAddressResolution(addr, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{} default: @@ -230,11 +232,11 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA } } -func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.Address, localAddr tcpip.Address, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */) + c.linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */) select { case now := <-time.After(c.resolutionTimeout): @@ -278,15 +280,18 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt return true } -func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { - c := &linkAddrCache{ +func (c *linkAddrCache) init(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int, linkRes LinkAddressResolver) { + *c = linkAddrCache{ nic: nic, + linkRes: linkRes, ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, } + + c.mu.Lock() c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) - return c + c.mu.Unlock() } var _ neighborTable = (*linkAddrCache)(nil) @@ -307,7 +312,7 @@ func (*linkAddrCache) removeAll() tcpip.Error { return &tcpip.ErrNotSupported{} } -func (c *linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress, _ LinkAddressResolver) { +func (c *linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress) { if len(linkAddr) != 0 { // NUD allows probes without a link address but linkAddrCache // is a simple neighbor table which does not implement NUD. diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 4df6f9265..9e7f331c9 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -77,10 +77,10 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe return 1 } -func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, tcpip.Error) { +func getBlocking(c *linkAddrCache, addr tcpip.Address) (tcpip.LinkAddress, tcpip.Error) { var attemptedResolution bool for { - got, ch, err := c.get(addr, linkRes, "", nil) + got, ch, err := c.get(addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); ok { if attemptedResolution { return got, &tcpip.ErrTimeout{} @@ -100,27 +100,28 @@ func newEmptyNIC() *NIC { } func TestCacheOverflow(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) + var c linkAddrCache + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) for i := len(testAddrs) - 1; i >= 0; i-- { e := testAddrs[i] c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("insert %d, c.get(%s, nil, '', nil): %s", i, e.addr, err) + t.Errorf("insert %d, c.get(%s, '', nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("insert %d, got c.get(%s, nil, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) + t.Errorf("insert %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { e := testAddrs[i] - got, _, err := c.get(e.addr, nil, "", nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("check %d, c.get(%s, nil, '', nil): %s", i, e.addr, err) + t.Errorf("check %d, c.get(%s, '', nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("check %d, got c.get(%s, nil, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) + t.Errorf("check %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // The earliest entries should no longer be in the cache. @@ -135,8 +136,9 @@ func TestCacheOverflow(t *testing.T) { } func TestCacheConcurrent(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, linkRes) var wg sync.WaitGroup for r := 0; r < 16; r++ { @@ -154,12 +156,12 @@ func TestCacheConcurrent(t *testing.T) { // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, _, '', nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, _, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } e = testAddrs[0] @@ -171,38 +173,40 @@ func TestCacheConcurrent(t *testing.T) { } func TestCacheAgeLimit(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3, linkRes) e := testAddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - _, _, err := c.get(e.addr, linkRes, "", nil) + _, _, err := c.get(e.addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.get(%s, _, '', nil) = %s, want = ErrWouldBlock", e.addr, err) + t.Errorf("got c.get(%s, '', nil) = %s, want = ErrWouldBlock", e.addr, err) } } func TestCacheReplace(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) + var c linkAddrCache + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) e := testAddrs[0] l2 := e.linkAddr + "2" c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, nil, '', nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, nil, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } c.add(e.addr, l2) - got, _, err = c.get(e.addr, nil, "", nil) + got, _, err = c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, nil, '', nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != l2 { - t.Errorf("got c.get(%s, nil, '', nil) = %s, want = %s", e.addr, got, l2) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, l2) } } @@ -213,34 +217,36 @@ func TestCacheResolution(t *testing.T) { // // Using a large resolution timeout decreases the probability of experiencing // this race condition and does not affect how long this test takes to run. - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1, linkRes) for i, ta := range testAddrs { - got, err := getBlocking(c, ta.addr, linkRes) + got, err := getBlocking(&c, ta.addr) if err != nil { - t.Errorf("check %d, getBlocking(_, %s, _): %s", i, ta.addr, err) + t.Errorf("check %d, getBlocking(_, %s): %s", i, ta.addr, err) } if got != ta.linkAddr { - t.Errorf("check %d, got getBlocking(_, %s, _) = %s, want = %s", i, ta.addr, got, ta.linkAddr) + t.Errorf("check %d, got getBlocking(_, %s) = %s, want = %s", i, ta.addr, got, ta.linkAddr) } } // Check that after resolved, address stays in the cache and never returns WouldBlock. for i := 0; i < 10; i++ { e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, _, '', nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, _, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } } } func TestCacheResolutionFailed(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5, linkRes) var requestCount uint32 linkRes.onLinkAddressRequest = func() { @@ -249,20 +255,20 @@ func TestCacheResolutionFailed(t *testing.T) { // First, sanity check that resolution is working... e := testAddrs[0] - got, err := getBlocking(c, e.addr, linkRes) + got, err := getBlocking(&c, e.addr) if err != nil { - t.Errorf("getBlocking(_, %s, _): %s", e.addr, err) + t.Errorf("getBlocking(_, %s): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got getBlocking(_, %s, _) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got getBlocking(_, %s) = %s, want = %s", e.addr, got, e.linkAddr) } before := atomic.LoadUint32(&requestCount) e.addr += "2" - a, err := getBlocking(c, e.addr, linkRes) + a, err := getBlocking(&c, e.addr) if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) + t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { @@ -273,12 +279,13 @@ func TestCacheResolutionFailed(t *testing.T) { func TestCacheResolutionTimeout(t *testing.T) { resolverDelay := 500 * time.Millisecond expiration := resolverDelay / 10 - c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3) - linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c, delay: resolverDelay} + c.init(newEmptyNIC(), expiration, 1*time.Millisecond, 3, linkRes) e := testAddrs[0] - a, err := getBlocking(c, e.addr, linkRes) + a, err := getBlocking(&c, e.addr) if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) + t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index c13be137e..0238605af 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -2796,8 +2796,8 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN NIC: nicID, }}) - if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) + if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err) } return ndpDisp, e, s } diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 64b8046f5..7e3132058 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -43,8 +43,9 @@ type NeighborStats struct { // Their state is always Static. The amount of static entries stored in the // cache is unbounded. type neighborCache struct { - nic *NIC - state *NUDState + nic *NIC + state *NUDState + linkRes LinkAddressResolver // mu protects the fields below. mu sync.RWMutex @@ -69,7 +70,7 @@ type neighborCache struct { // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. -func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { +func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntry { n.mu.Lock() defer n.mu.Unlock() @@ -85,7 +86,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // The entry that needs to be created must be dynamic since all static // entries are directly added to the cache via addStaticEntry. - entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes) + entry := newNeighborEntry(n, remoteAddr, n.state) if n.dynamic.count == neighborCacheSize { e := n.dynamic.lru.Back() e.mu.Lock() @@ -122,8 +123,8 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // packet prompting NUD/link address resolution. // // TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) { - entry := n.getOrCreateEntry(remoteAddr, linkRes) +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) { + entry := n.getOrCreateEntry(remoteAddr) entry.mu.Lock() defer entry.mu.Unlock() @@ -202,7 +203,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd entry.mu.Unlock() } - n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) + n.cache[addr] = newStaticNeighborEntry(n, addr, linkAddr, n.state) } // removeEntry removes a dynamic or static entry by address from the neighbor @@ -265,8 +266,8 @@ func (n *neighborCache) neighbors() ([]NeighborEntry, tcpip.Error) { return n.entries(), nil } -func (n *neighborCache) get(addr tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - entry, ch, err := n.entry(addr, localAddr, linkRes, onResolve) +func (n *neighborCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + entry, ch, err := n.entry(addr, localAddr, onResolve) return entry.LinkAddr, ch, err } @@ -286,8 +287,8 @@ func (n *neighborCache) removeAll() tcpip.Error { // handleProbe handles a neighbor probe as defined by RFC 4861 section 7.2.3. // // Validation of the probe is expected to be handled by the caller. -func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { - entry := n.getOrCreateEntry(remoteAddr, linkRes) +func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + entry := n.getOrCreateEntry(remoteAddr) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 122888fcf..b489b5e08 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -76,10 +76,15 @@ func entryDiffOptsWithSort() []cmp.Option { })) } -func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { +func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *testNeighborResolver { config.resetInvalidFields() rng := rand.New(rand.NewSource(time.Now().UnixNano())) - neigh := &neighborCache{ + linkRes := &testNeighborResolver{ + clock: clock, + entries: newTestEntryStore(), + delay: typicalLatency, + } + linkRes.neigh = &neighborCache{ nic: &NIC{ stack: &Stack{ clock: clock, @@ -88,10 +93,11 @@ func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock id: 1, stats: makeNICStats(), }, - state: NewNUDState(config, rng), - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + state: NewNUDState(config, rng), + linkRes: linkRes, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } - return neigh + return linkRes } // testEntryStore contains a set of IP to NeighborEntry mappings. @@ -241,10 +247,10 @@ func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) + linkRes := newTestNeighborResolver(&nudDisp, c, clock) - if got, want := neigh.config(), c; got != want { - t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + if got, want := linkRes.neigh.config(), c; got != want { + t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) } // No events should have been dispatched. @@ -259,14 +265,14 @@ func TestNeighborCacheSetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) + linkRes := newTestNeighborResolver(&nudDisp, c, clock) c.MinRandomFactor = 1 c.MaxRandomFactor = 1 - neigh.setConfig(c) + linkRes.neigh.setConfig(c) - if got, want := neigh.config(), c; got != want { - t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + if got, want := linkRes.neigh.config(), c; got != want { + t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) } // No events should have been dispatched. @@ -281,22 +287,15 @@ func TestNeighborCacheEntry(t *testing.T) { c := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, c, clock) - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -328,8 +327,8 @@ func TestNeighborCacheEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + if _, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err != nil { + t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -345,23 +344,16 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -393,7 +385,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } - neigh.removeEntry(entry.Addr) + linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ @@ -416,17 +408,15 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } { - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } } } type testContext struct { clock *faketime.ManualClock - neigh *neighborCache - store *testEntryStore linkRes *testNeighborResolver nudDisp *testNUDDispatcher } @@ -434,19 +424,10 @@ type testContext struct { func newTestContext(c NUDConfigurations) testContext { nudDisp := &testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nudDisp, c, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(nudDisp, c, clock) return testContext{ clock: clock, - neigh: neigh, - store: store, linkRes: linkRes, nudDisp: nudDisp, } @@ -460,17 +441,17 @@ type overflowOptions struct { func (c *testContext) overflowCache(opts overflowOptions) error { // Fill the neighbor cache to capacity to verify the LRU eviction strategy is // working properly after the entry removal. - for i := opts.startAtEntryIndex; i < c.store.size(); i++ { + for i := opts.startAtEntryIndex; i < c.linkRes.entries.size(); i++ { // Add a new entry - entry, ok := c.store.entry(i) + entry, ok := c.linkRes.entries.entry(i) if !ok { - return fmt.Errorf("c.store.entry(%d) not found", i) + return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + return fmt.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } - c.clock.Advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) var wantEvents []testEntryEventInfo @@ -478,9 +459,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { // LRU eviction strategy. Note that the number of static entries should not // affect the total number of dynamic entries that can be added. if i >= neighborCacheSize+opts.startAtEntryIndex { - removedEntry, ok := c.store.entry(i - neighborCacheSize) + removedEntry, ok := c.linkRes.entries.entry(i - neighborCacheSize) if !ok { - return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize) + return fmt.Errorf("linkRes.entries.entry(%d) not found", i-neighborCacheSize) } wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestRemoved, @@ -523,10 +504,10 @@ func (c *testContext) overflowCache(opts overflowOptions) error { // by entries() is nondeterministic, so entries have to be sorted before // comparison. wantUnsortedEntries := opts.wantStaticEntries - for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { - entry, ok := c.store.entry(i) + for i := c.linkRes.entries.size() - neighborCacheSize; i < c.linkRes.entries.size(); i++ { + entry, ok := c.linkRes.entries.entry(i) if !ok { - return fmt.Errorf("c.store.entry(%d) not found", i) + return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -536,7 +517,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, c.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, c.linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } @@ -580,15 +561,15 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } - c.clock.Advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -617,7 +598,7 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { } // Remove the entry - c.neigh.removeEntry(entry.Addr) + c.linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ @@ -656,12 +637,12 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -682,7 +663,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { } // Remove the static entry that was just added - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) // No more events should have been dispatched. c.nudDisp.mu.Lock() @@ -700,12 +681,12 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -727,7 +708,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) // Add a duplicate entry with a different link address staticLinkAddr += "duplicate" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) { wantEvents := []testEntryEventInfo{ { @@ -762,12 +743,12 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -788,7 +769,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { } // Remove the static entry that was just added - c.neigh.removeEntry(entry.Addr) + c.linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ { @@ -832,13 +813,13 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -870,7 +851,7 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { // Override the entry with a static one using the same address staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) { wantEvents := []testEntryEventInfo{ { @@ -925,14 +906,14 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) - e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + c.linkRes.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) + e, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -940,7 +921,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { State: Static, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -982,23 +963,16 @@ func TestNeighborCacheClear(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) // Add a dynamic entry. - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -1030,7 +1004,7 @@ func TestNeighborCacheClear(t *testing.T) { } // Add a static entry. - neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) + linkRes.neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) { wantEvents := []testEntryEventInfo{ @@ -1054,7 +1028,7 @@ func TestNeighborCacheClear(t *testing.T) { } // Clear should remove both dynamic and static entries. - neigh.clear() + linkRes.neigh.clear() // Remove events dispatched from clear() have no deterministic order so they // need to be sorted beforehand. @@ -1098,13 +1072,13 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -1135,7 +1109,7 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { } // Clear the cache. - c.neigh.clear() + c.linkRes.neigh.clear() { wantEvents := []testEntryEventInfo{ { @@ -1174,18 +1148,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - frequentlyUsedEntry, ok := store.entry(0) + frequentlyUsedEntry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // The following logic is very similar to overflowCache, but @@ -1193,23 +1160,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // Fill the neighbor cache to capacity for i := 0; i < neighborCacheSize; i++ { - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { @@ -1240,38 +1207,38 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Keep adding more entries - for i := neighborCacheSize; i < store.size(); i++ { + for i := neighborCacheSize; i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { - if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err) + if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", frequentlyUsedEntry.Addr, err) } } - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy - removedEntry, ok := store.entry(i - neighborCacheSize + 1) + removedEntry, ok := linkRes.entries.entry(i - neighborCacheSize + 1) if !ok { - t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1) + t.Fatalf("linkRes.entries.entry(%d) not found", i-neighborCacheSize+1) } wantEvents := []testEntryEventInfo{ { @@ -1321,10 +1288,10 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { }, } - for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ { - entry, ok := store.entry(i) + for i := linkRes.entries.size() - neighborCacheSize + 1; i < linkRes.entries.size(); i++ { + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -1334,7 +1301,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } @@ -1353,26 +1320,19 @@ func TestNeighborCacheConcurrent(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - storeEntries := store.entries() + storeEntries := linkRes.entries.entries() for _, entry := range storeEntries { var wg sync.WaitGroup for r := 0; r < concurrentProcesses; r++ { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - switch e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err.(type) { + switch e, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err.(type) { case nil, *tcpip.ErrWouldBlock: default: - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) } }(entry) } @@ -1390,10 +1350,10 @@ func TestNeighborCacheConcurrent(t *testing.T) { // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. var wantUnsortedEntries []NeighborEntry - for i := store.size() - neighborCacheSize; i < store.size(); i++ { - entry, ok := store.entry(i) + for i := linkRes.entries.size() - neighborCacheSize; i < linkRes.entries.size(); i++ { + entry, ok := linkRes.entries.entry(i) if !ok { - t.Errorf("store.entry(%d) not found", i) + t.Errorf("linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -1403,7 +1363,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } } @@ -1413,41 +1373,34 @@ func TestNeighborCacheReplace(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) // Add an entry - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } // Verify the entry exists { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) } if t.Failed() { t.FailNow() @@ -1458,21 +1411,21 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } // Notify of a link address change var updatedLinkAddr tcpip.LinkAddress { - entry, ok := store.entry(1) + entry, ok := linkRes.entries.entry(1) if !ok { - t.Fatal("store.entry(1) not found") + t.Fatal("linkRes.entries.entry(1) not found") } updatedLinkAddr = entry.LinkAddr } - store.set(0, updatedLinkAddr) - neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ + linkRes.entries.set(0, updatedLinkAddr) + linkRes.neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, @@ -1482,9 +1435,9 @@ func TestNeighborCacheReplace(t *testing.T) { // // Verify the entry's new link address and the new state. { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1492,17 +1445,17 @@ func TestNeighborCacheReplace(t *testing.T) { State: Delay, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } // Verify that the neighbor is now reachable. { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1510,7 +1463,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } } @@ -1520,46 +1473,39 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() + linkRes := newTestNeighborResolver(&nudDisp, config, clock) var requestCount uint32 - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - onLinkAddressRequest: func() { - atomic.AddUint32(&requestCount, 1) - }, + linkRes.onLinkAddressRequest = func() { + atomic.AddUint32(&requestCount, 1) } - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // First, sanity check that resolution is working { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } - got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + got, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1567,7 +1513,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { State: Reachable, } if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } // Verify address resolution fails for an unknown address. @@ -1575,24 +1521,24 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { entry.Addr += "2" { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } - maxAttempts := neigh.config().MaxUnicastProbes + maxAttempts := linkRes.neigh.config().MaxUnicastProbes if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want { t.Errorf("got link address request count = %d, want = %d", got, want) } @@ -1606,27 +1552,22 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { config.RetransmitTimer = time.Millisecond // small enough to cause timeout clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: time.Minute, // large enough to cause timeout - } + linkRes := newTestNeighborResolver(nil, config, clock) + // large enough to cause timeout + linkRes.delay = time.Minute - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1634,7 +1575,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } @@ -1643,31 +1584,24 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { func TestNeighborCacheRetryResolution(t *testing.T) { config := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - // Simulate a faulty link. - dropReplies: true, - } + linkRes := newTestNeighborResolver(nil, config, clock) + // Simulate a faulty link. + linkRes.dropReplies = true - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // Perform address resolution with a faulty link, which will fail. { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1675,7 +1609,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } @@ -1687,20 +1621,20 @@ func TestNeighborCacheRetryResolution(t *testing.T) { State: Failed, }, } - if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) } // Retry address resolution with a working link. linkRes.dropReplies = false { - incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + incompleteEntry, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } if incompleteEntry.State != Incomplete { t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) @@ -1712,9 +1646,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) { if !ok { t.Fatal("expected successful address resolution") } - reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + reachableEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) + t.Fatalf("linkRes.neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) } if reachableEntry.Addr != entry.Addr { t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr) @@ -1726,7 +1660,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String()) } default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } } @@ -1735,42 +1669,36 @@ func BenchmarkCacheClear(b *testing.B) { b.StopTimer() config := DefaultNUDConfigurations() clock := &tcpip.StdClock{} - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: 0, - } + linkRes := newTestNeighborResolver(nil, config, clock) + linkRes.delay = 0 // Clear for every possible size of the cache for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. for i := 0; i < cacheSize; i++ { - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - b.Fatalf("store.entry(%d) not found", i) + b.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + b.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } select { case <-ch: default: - b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } b.StartTimer() - neigh.clear() + linkRes.neigh.clear() b.StopTimer() } } diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index a037ca6f9..b05f96d4f 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -77,11 +77,7 @@ const ( type neighborEntry struct { neighborEntryEntry - nic *NIC - - // linkRes provides the functionality to send reachability probes, used in - // Neighbor Unreachability Detection. - linkRes LinkAddressResolver + cache *neighborCache // nudState points to the Neighbor Unreachability Detection configuration. nudState *NUDState @@ -106,10 +102,9 @@ type neighborEntry struct { // state, Unknown. Transition out of Unknown by calling either // `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created // neighborEntry. -func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { +func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState *NUDState) *neighborEntry { return &neighborEntry{ - nic: nic, - linkRes: linkRes, + cache: cache, nudState: nudState, neigh: NeighborEntry{ Addr: remoteAddr, @@ -121,18 +116,18 @@ func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, li // newStaticNeighborEntry creates a neighbor cache entry starting at the // Static state. The entry can only transition out of Static by directly // calling `setStateLocked`. -func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { +func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { entry := NeighborEntry{ Addr: addr, LinkAddr: linkAddr, State: Static, - UpdatedAtNanos: nic.stack.clock.NowNanoseconds(), + UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(), } - if nic.stack.nudDisp != nil { - nic.stack.nudDisp.OnNeighborAdded(nic.id, entry) + if nudDisp := cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(cache.nic.id, entry) } return &neighborEntry{ - nic: nic, + cache: cache, nudState: state, neigh: entry, } @@ -158,7 +153,7 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // is resolved (which ends up obtaining the entry's lock) while holding the // link resolution queue's lock. Dequeuing packets in a new goroutine avoids // a lock ordering violation. - go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) + go e.cache.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) } } @@ -167,8 +162,8 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchAddEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(e.cache.nic.id, e.neigh) } } @@ -177,8 +172,8 @@ func (e *neighborEntry) dispatchAddEventLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchChangeEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborChanged(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborChanged(e.cache.nic.id, e.neigh) } } @@ -187,8 +182,8 @@ func (e *neighborEntry) dispatchChangeEventLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchRemoveEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborRemoved(e.cache.nic.id, e.neigh) } } @@ -206,7 +201,7 @@ func (e *neighborEntry) cancelJobLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) removeLocked() { - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() e.dispatchRemoveEventLocked() e.cancelJobLocked() e.notifyCompletionLocked(false /* succeeded */) @@ -222,7 +217,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.neigh.State e.neigh.State = next - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() config := e.nudState.Config() switch next { @@ -230,14 +225,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev)) case Reachable: - e.job = e.nic.stack.newJob(&e.mu, func() { + e.job = e.cache.nic.stack.newJob(&e.mu, func() { e.setStateLocked(Stale) e.dispatchChangeEventLocked() }) e.job.Schedule(e.nudState.ReachableTime()) case Delay: - e.job = e.nic.stack.newJob(&e.mu, func() { + e.job = e.cache.nic.stack.newJob(&e.mu, func() { e.setStateLocked(Probe) e.dispatchChangeEventLocked() }) @@ -254,14 +249,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil { + if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return } retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(config.RetransmitTimer) } @@ -269,7 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // for finishing the state transition. This is necessary to avoid // deadlock where sending and processing probes are done synchronously, // such as loopback and integration tests. - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(immediateDuration) case Failed: @@ -292,12 +287,12 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { case Failed: - e.nic.stats.Neighbor.FailedEntryLookups.Increment() + e.cache.nic.stats.Neighbor.FailedEntryLookups.Increment() fallthrough case Unknown: e.neigh.State = Incomplete - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() e.dispatchAddEventLocked() @@ -340,7 +335,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // address SHOULD be placed in the IP Source Address of the outgoing // solicitation. // - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil { + if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil { // There is no need to log the error here; the NUD implementation may // assume a working link. A valid link should be the responsibility of // the NIC/stack.LinkEndpoint. @@ -350,7 +345,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { } retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) e.job.Schedule(config.RetransmitTimer) } @@ -358,7 +353,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // for finishing the state transition. This is necessary to avoid // deadlock where sending and processing probes are done synchronously, // such as loopback and integration tests. - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) e.job.Schedule(immediateDuration) case Stale: @@ -504,7 +499,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // // TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6 // here. - ep, ok := e.nic.networkEndpoints[header.IPv6ProtocolNumber] + ep, ok := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber] if !ok { panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint")) } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 5e5e0e6ca..57cfbdb8b 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -230,23 +230,30 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e }, stats: makeNICStats(), } + netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ - header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil), + header.IPv6ProtocolNumber: netEP, } rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) - linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes) - + var linkRes entryTestLinkResolver // Stub out the neighbor cache to verify deletion from the cache. neigh := &neighborCache{ - nic: &nic, - state: nudState, - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + nic: &nic, + state: nudState, + linkRes: &linkRes, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + l := linkResolver{ + resolver: &linkRes, + neighborTable: neigh, } + entry := newNeighborEntry(neigh, entryTestAddr1 /* remoteAddr */, nudState) neigh.cache[entryTestAddr1] = entry - nic.neighborTable = neigh + nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]linkResolver{ + header.IPv6ProtocolNumber: l, + } return entry, &disp, &linkRes, clock } @@ -836,7 +843,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) - ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) + ipv6EP := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index c813b0da5..693ea064a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -27,11 +27,11 @@ import ( type neighborTable interface { neighbors() ([]NeighborEntry, tcpip.Error) addStaticEntry(tcpip.Address, tcpip.LinkAddress) - get(addr tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) + get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) remove(tcpip.Address) tcpip.Error removeAll() tcpip.Error - handleProbe(tcpip.Address, tcpip.LinkAddress, LinkAddressResolver) + handleProbe(tcpip.Address, tcpip.LinkAddress) handleConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) handleUpperLevelConfirmation(tcpip.Address) @@ -41,6 +41,20 @@ type neighborTable interface { var _ NetworkInterface = (*NIC)(nil) +type linkResolver struct { + resolver LinkAddressResolver + + neighborTable neighborTable +} + +func (l *linkResolver) getNeighborLinkAddress(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + return l.neighborTable.get(addr, localAddr, onResolve) +} + +func (l *linkResolver) confirmReachable(addr tcpip.Address) { + l.neighborTable.handleUpperLevelConfirmation(addr) +} + // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { @@ -56,7 +70,7 @@ type NIC struct { // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint - linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + linkAddrResolvers map[tcpip.NetworkProtocolNumber]linkResolver // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. // @@ -67,8 +81,6 @@ type NIC struct { // complete. linkResQueue packetsPendingLinkResolution - neighborTable neighborTable - mu struct { sync.RWMutex spoofing bool @@ -153,25 +165,13 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC context: ctx, stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]linkResolver), } nic.linkResQueue.init(nic) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0 - if resolutionRequired { - if stack.useNeighborCache { - nic.neighborTable = &neighborCache{ - nic: nic, - state: NewNUDState(stack.nudConfigs, stack.randomGenerator), - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), - } - } else { - nic.neighborTable = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts) - } - } - // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { nic.mu.packetEPs[netProto] = new(packetEndpointList) @@ -185,7 +185,24 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC if resolutionRequired { if r, ok := netEP.(LinkAddressResolver); ok { - nic.linkAddrResolvers[r.LinkAddressProtocol()] = r + l := linkResolver{ + resolver: r, + } + + if stack.useNeighborCache { + l.neighborTable = &neighborCache{ + nic: nic, + state: NewNUDState(stack.nudConfigs, stack.randomGenerator), + linkRes: r, + + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + } else { + cache := new(linkAddrCache) + cache.init(nic, ageLimit, resolutionTimeout, resolutionAttempts, r) + l.neighborTable = cache + } + nic.linkAddrResolvers[r.LinkAddressProtocol()] = l } } } @@ -240,18 +257,19 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() - } - // Clear the neighbour table (including static entries) as we cannot guarantee - // that the current neighbour table will be valid when the NIC is enabled - // again. - // - // This matches linux's behaviour at the time of writing: - // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 - switch err := n.clearNeighbors(); err.(type) { - case nil, *tcpip.ErrNotSupported: - default: - panic(fmt.Sprintf("n.clearNeighbors(): %s", err)) + // Clear the neighbour table (including static entries) as we cannot + // guarantee that the current neighbour table will be valid when the NIC is + // enabled again. + // + // This matches linux's behaviour at the time of writing: + // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 + netProto := ep.NetworkProtocolNumber() + switch err := n.clearNeighbors(netProto); err.(type) { + case nil, *tcpip.ErrNotSupported: + default: + panic(fmt.Sprintf("n.clearNeighbors(%d): %s", netProto, err)) + } } if !n.setEnabled(false) { @@ -604,63 +622,49 @@ func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error { return &tcpip.ErrBadLocalAddress{} } -func (n *NIC) confirmReachable(addr tcpip.Address) { - if n.neighborTable != nil { - n.neighborTable.handleUpperLevelConfirmation(addr) - } -} - func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { linkRes, ok := n.linkAddrResolvers[protocol] if !ok { return &tcpip.ErrNotSupported{} } - if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok { + if linkAddr, ok := linkRes.resolver.ResolveStaticAddress(addr); ok { onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) return nil } - _, _, err := n.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) + _, _, err := linkRes.getNeighborLinkAddress(addr, localAddr, onResolve) return err } -func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - if n.neighborTable != nil { - return n.neighborTable.get(addr, linkRes, localAddr, onResolve) - } - - return "", nil, &tcpip.ErrNotSupported{} -} - -func (n *NIC) neighbors() ([]NeighborEntry, tcpip.Error) { - if n.neighborTable != nil { - return n.neighborTable.neighbors() +func (n *NIC) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.neighbors() } return nil, &tcpip.ErrNotSupported{} } -func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) tcpip.Error { - if n.neighborTable != nil { - n.neighborTable.addStaticEntry(addr, linkAddress) +func (n *NIC) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtocolNumber, linkAddress tcpip.LinkAddress) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + linkRes.neighborTable.addStaticEntry(addr, linkAddress) return nil } return &tcpip.ErrNotSupported{} } -func (n *NIC) removeNeighbor(addr tcpip.Address) tcpip.Error { - if n.neighborTable != nil { - return n.neighborTable.remove(addr) +func (n *NIC) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.remove(addr) } return &tcpip.ErrNotSupported{} } -func (n *NIC) clearNeighbors() tcpip.Error { - if n.neighborTable != nil { - return n.neighborTable.removeAll() +func (n *NIC) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.removeAll() } return &tcpip.ErrNotSupported{} @@ -947,9 +951,9 @@ func (n *NIC) Name() string { } // nudConfigs gets the NUD configurations for n. -func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) { - if n.neighborTable != nil { - return n.neighborTable.nudConfig() +func (n *NIC) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.nudConfig() } return NUDConfigurations{}, &tcpip.ErrNotSupported{} @@ -959,10 +963,10 @@ func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) { // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (n *NIC) setNUDConfigs(c NUDConfigurations) tcpip.Error { - if n.neighborTable != nil { +func (n *NIC) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { c.resetInvalidFields() - return n.neighborTable.setNUDConfig(c) + return linkRes.neighborTable.setNUDConfig(c) } return &tcpip.ErrNotSupported{} @@ -1003,15 +1007,21 @@ func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { } // HandleNeighborProbe implements NetworkInterface. -func (n *NIC) HandleNeighborProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { - if n.neighborTable != nil { - n.neighborTable.handleProbe(addr, linkAddr, linkRes) +func (n *NIC) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { + if l, ok := n.linkAddrResolvers[protocol]; ok { + l.neighborTable.handleProbe(addr, linkAddr) + return nil } + + return &tcpip.ErrNotSupported{} } // HandleNeighborConfirmation implements NetworkInterface. -func (n *NIC) HandleNeighborConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { - if n.neighborTable != nil { - n.neighborTable.handleConfirmation(addr, linkAddr, flags) +func (n *NIC) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) tcpip.Error { + if l, ok := n.linkAddrResolvers[protocol]; ok { + l.neighborTable.handleConfirmation(addr, linkAddr, flags) + return nil } + + return &tcpip.ErrNotSupported{} } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 504acc246..e9acef6a2 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -19,7 +19,9 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -52,66 +54,146 @@ func (f *fakeRand) Float32() float32 { return f.num } -// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if -// we attempt to update NUD configurations using an invalid NICID. -func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The networking - // stack will only allocate neighbor caches if a protocol providing link - // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - UseNeighborCache: true, - }) +func TestNUDFunctions(t *testing.T) { + const nicID = 1 - // No NIC with ID 1 yet. - config := stack.NUDConfigurations{} - err := s.SetNUDConfigurations(1, config) - if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { - t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, &tcpip.ErrUnknownNICID{}) + tests := []struct { + name string + nicID tcpip.NICID + netProtoFactory []stack.NetworkProtocolFactory + extraLinkCapabilities stack.LinkEndpointCapabilities + expectedErr tcpip.Error + }{ + { + name: "Invalid NICID", + nicID: nicID + 1, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + expectedErr: &tcpip.ErrUnknownNICID{}, + }, + { + name: "No network protocol", + nicID: nicID, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With IPv6", + nicID: nicID, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With resolution capability", + nicID: nicID, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With IPv6 and resolution capability", + nicID: nicID, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + }, } -} -// TestNUDConfigurationFailsForNotSupported tests to make sure we get a -// NotSupported error if we attempt to retrieve or set NUD configurations when -// the stack doesn't support NUD. -// -// The stack will report to not support NUD if a neighbor cache for a given NIC -// is not allocated. The networking stack will only allocate neighbor caches if -// the NIC requires link resolution. -func TestNUDConfigurationFailsForNotSupported(t *testing.T) { - const nicID = 1 + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, + NetworkProtocols: test.netProtoFactory, + Clock: clock, + }) - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired + e := channel.New(0, 0, linkAddr1) + e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired + e.LinkEPCapabilities |= test.extraLinkCapabilities - s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - t.Run("Get", func(t *testing.T) { - _, err := s.NUDConfigurations(nicID) - if _, ok := err.(*tcpip.ErrNotSupported); !ok { - t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{}) - } - }) + configs := stack.DefaultNUDConfigurations() + configs.BaseReachableTime = time.Hour - t.Run("Set", func(t *testing.T) { - config := stack.NUDConfigurations{} - err := s.SetNUDConfigurations(nicID, config) - if _, ok := err.(*tcpip.ErrNotSupported); !ok { - t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{}) - } - }) + { + err := s.SetNUDConfigurations(test.nicID, ipv6.ProtocolNumber, configs) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.SetNUDConfigurations(%d, %d, _) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } + } + + { + gotConfigs, err := s.NUDConfigurations(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.NUDConfigurations(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if diff := cmp.Diff(configs, gotConfigs); diff != "" { + t.Errorf("got configs mismatch (-want +got):\n%s", diff) + } + } + } + + for _, addr := range []tcpip.Address{llAddr1, llAddr2} { + { + err := s.AddStaticNeighbor(test.nicID, ipv6.ProtocolNumber, addr, linkAddr1) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.AddStaticNeighbor(%d, %d, %s, %s) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, addr, linkAddr1, diff) + } + } + } + + { + wantErr := test.expectedErr + for i := 0; i < 2; i++ { + { + err := s.RemoveNeighbor(test.nicID, ipv6.ProtocolNumber, llAddr1) + if diff := cmp.Diff(wantErr, err); diff != "" { + t.Errorf("s.RemoveNeighbor(%d, %d, '') error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } + } + + if test.expectedErr != nil { + break + } + + // Removing a neighbor that does not exist should give us a bad address + // error. + wantErr = &tcpip.ErrBadAddress{} + } + } + + { + neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if diff := cmp.Diff( + []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + neighbors, + ); diff != "" { + t.Errorf("neighbors mismatch (-want +got):\n%s", diff) + } + } + } + + { + err := s.ClearNeighbors(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.ClearNeigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber); err != nil { + t.Errorf("s.Neighbors(%d, %d): %s", test.nicID, ipv6.ProtocolNumber, err) + } else if len(neighbors) != 0 { + t.Errorf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + } + } + } + }) + } } -// TestDefaultNUDConfigurationIsValid verifies that calling -// resetInvalidFields() on the result of DefaultNUDConfigurations() does not -// change anything. DefaultNUDConfigurations() should return a valid -// NUDConfigurations. func TestDefaultNUDConfigurations(t *testing.T) { const nicID = 1 @@ -129,12 +211,12 @@ func TestDefaultNUDConfigurations(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - c, err := s.NUDConfigurations(nicID) + c, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got, want := c, stack.DefaultNUDConfigurations(); got != want { - t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want) + t.Errorf("got stack.NUDConfigurations(%d, %d) = %+v, want = %+v", nicID, ipv6.ProtocolNumber, got, want) } } @@ -184,9 +266,9 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.BaseReachableTime; got != test.want { t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want) @@ -241,9 +323,9 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MinRandomFactor; got != test.want { t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want) @@ -321,9 +403,9 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxRandomFactor; got != test.want { t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want) @@ -383,9 +465,9 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.RetransmitTimer; got != test.want { t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want) @@ -435,9 +517,9 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.DelayFirstProbeTime; got != test.want { t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want) @@ -487,9 +569,9 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxMulticastProbes; got != test.want { t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want) @@ -539,9 +621,9 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxUnicastProbes; got != test.want { t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index c652c2bd7..e02f7190c 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -536,11 +536,11 @@ type NetworkInterface interface { // // HandleNeighborProbe assumes that the probe is valid for the network // interface the probe was received on. - HandleNeighborProbe(tcpip.Address, tcpip.LinkAddress, LinkAddressResolver) + HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error // HandleNeighborConfirmation processes an incoming neighbor confirmation // (e.g. ARP reply or NDP Neighbor Advertisement). - HandleNeighborConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) + HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) tcpip.Error } // LinkResolvableNetworkEndpoint handles link resolution events. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 1c8ef6ed4..bab55ce49 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -53,7 +53,7 @@ type Route struct { // linkRes is set if link address resolution is enabled for this protocol on // the route's NIC. - linkRes LinkAddressResolver + linkRes linkResolver } type routeInfo struct { @@ -184,11 +184,11 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA return r } - if r.linkRes == nil { + if r.linkRes.resolver == nil { return r } - if linkAddr, ok := r.linkRes.ResolveStaticAddress(r.RemoteAddress); ok { + if linkAddr, ok := r.linkRes.resolver.ResolveStaticAddress(r.RemoteAddress); ok { r.ResolveWith(linkAddr) return r } @@ -362,7 +362,7 @@ func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteIn } afterResolveFields := fields - linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) { + linkAddr, ch, err := r.linkRes.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, func(r LinkResolutionResult) { if afterResolve != nil { if r.Success { afterResolveFields.RemoteLinkAddress = r.LinkAddress @@ -400,7 +400,7 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isResolutionRequiredRLocked() bool { - return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local() + return len(r.mu.remoteLinkAddress) == 0 && r.linkRes.resolver != nil && r.isValidForOutgoingRLocked() && !r.local() } func (r *Route) isValidForOutgoing() bool { @@ -528,5 +528,7 @@ func (r *Route) IsOutboundBroadcast() bool { // "Reachable" is defined as having full-duplex communication between the // local and remote ends of the route. func (r *Route) ConfirmReachable() { - r.outgoingNIC.confirmReachable(r.nextHop()) + if r.linkRes.resolver != nil { + r.linkRes.confirmReachable(r.nextHop()) + } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 73db6e031..9390aaf57 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1560,7 +1560,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, } // Neighbors returns all IP to MAC address associations. -func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) { +func (s *Stack) Neighbors(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1569,11 +1569,11 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) { return nil, &tcpip.ErrUnknownNICID{} } - return nic.neighbors() + return nic.neighbors(protocol) } // AddStaticNeighbor statically associates an IP address to a MAC address. -func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { +func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1582,13 +1582,13 @@ func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAdd return &tcpip.ErrUnknownNICID{} } - return nic.addStaticNeighbor(addr, linkAddr) + return nic.addStaticNeighbor(addr, protocol, linkAddr) } // RemoveNeighbor removes an IP to MAC address association previously created // either automically or by AddStaticNeighbor. Returns ErrBadAddress if there // is no association with the provided address. -func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Error { +func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1597,11 +1597,11 @@ func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Erro return &tcpip.ErrUnknownNICID{} } - return nic.removeNeighbor(addr) + return nic.removeNeighbor(protocol, addr) } // ClearNeighbors removes all IP to MAC address associations. -func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error { +func (s *Stack) ClearNeighbors(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1610,7 +1610,7 @@ func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error { return &tcpip.ErrUnknownNICID{} } - return nic.clearNeighbors() + return nic.clearNeighbors(protocol) } // RegisterTransportEndpoint registers the given endpoint with the stack @@ -1980,7 +1980,7 @@ func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtoco } // NUDConfigurations gets the per-interface NUD configurations. -func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Error) { +func (s *Stack) NUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() @@ -1989,14 +1989,14 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Erro return NUDConfigurations{}, &tcpip.ErrUnknownNICID{} } - return nic.nudConfigs() + return nic.nudConfigs(proto) } // SetNUDConfigurations sets the per-interface NUD configurations. // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip.Error { +func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() @@ -2005,7 +2005,7 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip. return &tcpip.ErrUnknownNICID{} } - return nic.setNUDConfigs(c) + return nic.setNUDConfigs(proto, c) } // Seed returns a 32 bit value that can be used as a seed value for port diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index a166c0502..375cd3080 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -4313,9 +4314,11 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") ) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, UseNeighborCache: true, + Clock: clock, }) e := channel.New(0, 0, "") e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -4323,36 +4326,56 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddStaticNeighbor(nicID, ipv4Addr, linkAddr); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv4Addr, linkAddr, err) - } - if err := s.AddStaticNeighbor(nicID, ipv6Addr, linkAddr); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv6Addr, linkAddr, err) + addrs := []struct { + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + }{ + { + proto: ipv4.ProtocolNumber, + addr: ipv4Addr, + }, + { + proto: ipv6.ProtocolNumber, + addr: ipv6Addr, + }, } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 2 { - t.Fatalf("got len(neighbors) = %d, want = 2; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if err := s.AddStaticNeighbor(nicID, addr.proto, addr.addr, linkAddr); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, addr.proto, addr.addr, linkAddr, err) + } + + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if diff := cmp.Diff( + []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + neighbors, + ); diff != "" { + t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff) + } } // Disabling the NIC should clear the neighbor table. if err := s.DisableNIC(nicID); err != nil { t.Fatalf("s.DisableNIC(%d): %s", nicID, err) } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 0 { - t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if len(neighbors) != 0 { + t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) + } } // Enabling the NIC should have an empty neighbor table. if err := s.EnableNIC(nicID); err != nil { t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 0 { - t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if len(neighbors) != 0 { + t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) + } } } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 7069352f2..b3a5d49d7 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -1069,9 +1069,9 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // Wait for the remote's neighbor entry to be stale before creating a // TCP connection from host1 to some remote. - nudConfigs, err := host1Stack.NUDConfigurations(host1NICID) + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) if err != nil { - t.Fatalf("host1Stack.NUDConfigurations(%d): %s", host1NICID, err) + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) } // The maximum reachable time for a neighbor is some maximum random factor // applied to the base reachable time. |