summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/network/arp/arp.go18
-rw-r--r--pkg/tcpip/network/arp/arp_test.go4
-rw-r--r--pkg/tcpip/network/ip_test.go6
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go38
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go6
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go20
-rw-r--r--pkg/tcpip/stack/forwarding_test.go5
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go25
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go95
-rw-r--r--pkg/tcpip/stack/ndp_test.go4
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go23
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go434
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go61
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go25
-rw-r--r--pkg/tcpip/stack/nic.go150
-rw-r--r--pkg/tcpip/stack/nud_test.go218
-rw-r--r--pkg/tcpip/stack/registration.go4
-rw-r--r--pkg/tcpip/stack/route.go14
-rw-r--r--pkg/tcpip/stack/stack.go24
-rw-r--r--pkg/tcpip/stack/stack_test.go57
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go4
21 files changed, 669 insertions, 566 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 5fd4c5574..0d7fadc31 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -148,7 +148,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
remoteAddr := tcpip.Address(h.ProtocolAddressSender())
remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.nic.HandleNeighborProbe(remoteAddr, remoteLinkAddr, e)
+ switch err := e.nic.HandleNeighborProbe(header.IPv4ProtocolNumber, remoteAddr, remoteLinkAddr); err.(type) {
+ case nil:
+ case *tcpip.ErrNotSupported:
+ // The stack may support ARP but the NIC may not need link resolution.
+ default:
+ panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err))
+ }
respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
@@ -190,7 +196,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// The solicited, override, and isRouter flags are not available for ARP;
// they are only available for IPv6 Neighbor Advertisements.
- e.nic.HandleNeighborConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{
+ switch err := e.nic.HandleNeighborConfirmation(header.IPv4ProtocolNumber, addr, linkAddr, stack.ReachabilityConfirmationFlags{
// Solicited and unsolicited (also referred to as gratuitous) ARP Replies
// are handled equivalently to a solicited Neighbor Advertisement.
Solicited: true,
@@ -199,7 +205,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
Override: false,
// ARP does not distinguish between router and non-router hosts.
IsRouter: false,
- })
+ }); err.(type) {
+ case nil:
+ case *tcpip.ErrNotSupported:
+ // The stack may support ARP but the NIC may not need link resolution.
+ default:
+ panic(fmt.Sprintf("unexpected error when informing NIC of neighbor confirmation message: %s", err))
+ }
}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index d753a97af..24357e15d 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -491,9 +491,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
t.Fatal(err)
}
- neighbors, err := c.s.Neighbors(nicID)
+ neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber)
if err != nil {
- t.Fatalf("c.s.Neighbors(%d): %s", nicID, err)
+ t.Fatalf("c.s.Neighbors(%d, %d): %s", nicID, ipv4.ProtocolNumber, err)
}
neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 291330e8e..8d155344b 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -311,10 +311,12 @@ func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.N
return &tcpip.ErrNotSupported{}
}
-func (*testInterface) HandleNeighborProbe(tcpip.Address, tcpip.LinkAddress, stack.LinkAddressResolver) {
+func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error {
+ return nil
}
-func (*testInterface) HandleNeighborConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) {
+func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error {
+ return nil
}
func TestSourceAddressValidation(t *testing.T) {
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index bdc88fe5d..12e5ead5e 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -290,7 +290,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
received.invalid.Increment()
return
} else {
- e.nic.HandleNeighborProbe(srcAddr, sourceLinkAddr, e)
+ switch err := e.nic.HandleNeighborProbe(ProtocolNumber, srcAddr, sourceLinkAddr); err.(type) {
+ case nil:
+ case *tcpip.ErrNotSupported:
+ // The stack may support ICMPv6 but the NIC may not need link resolution.
+ default:
+ panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err))
+ }
}
// As per RFC 4861 section 7.1.1:
@@ -456,11 +462,17 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// If the NA message has the target link layer option, update the link
// address cache with the link address for the target of the message.
- e.nic.HandleNeighborConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{
+ switch err := e.nic.HandleNeighborConfirmation(ProtocolNumber, targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{
Solicited: na.SolicitedFlag(),
Override: na.OverrideFlag(),
IsRouter: na.RouterFlag(),
- })
+ }); err.(type) {
+ case nil:
+ case *tcpip.ErrNotSupported:
+ // The stack may support ICMPv6 but the NIC may not need link resolution.
+ default:
+ panic(fmt.Sprintf("unexpected error when informing NIC of neighbor confirmation message: %s", err))
+ }
case header.ICMPv6EchoRequest:
received.echoRequest.Increment()
@@ -566,9 +578,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
return
}
- // A RS with a specified source IP address modifies the NUD state
- // machine in the same way a reachability probe would.
- e.nic.HandleNeighborProbe(srcAddr, sourceLinkAddr, e)
+ // A RS with a specified source IP address modifies the neighbor table
+ // in the same way a regular probe would.
+ switch err := e.nic.HandleNeighborProbe(ProtocolNumber, srcAddr, sourceLinkAddr); err.(type) {
+ case nil:
+ case *tcpip.ErrNotSupported:
+ // The stack may support ICMPv6 but the NIC may not need link resolution.
+ default:
+ panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err))
+ }
}
case header.ICMPv6RouterAdvert:
@@ -617,7 +635,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// If the RA has the source link layer option, update the link address
// cache with the link address for the advertised router.
if len(sourceLinkAddr) != 0 {
- e.nic.HandleNeighborProbe(routerAddr, sourceLinkAddr, e)
+ switch err := e.nic.HandleNeighborProbe(ProtocolNumber, routerAddr, sourceLinkAddr); err.(type) {
+ case nil:
+ case *tcpip.ErrNotSupported:
+ // The stack may support ICMPv6 but the NIC may not need link resolution.
+ default:
+ panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err))
+ }
}
e.mu.Lock()
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 755293377..4374d0198 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -139,12 +139,14 @@ func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gs
return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt)
}
-func (t *testInterface) HandleNeighborProbe(tcpip.Address, tcpip.LinkAddress, stack.LinkAddressResolver) {
+func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error {
t.probeCount++
+ return nil
}
-func (t *testInterface) HandleNeighborConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) {
+func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error {
t.confirmationCount++
+ return nil
}
func TestICMPCounts(t *testing.T) {
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 4cc81e6cc..e0245487b 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -338,18 +338,18 @@ func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *test
Data: hdr.View().ToVectorisedView(),
}))
- neighbors, err := s.Neighbors(nicID)
+ neighbors, err := s.Neighbors(nicID, ProtocolNumber)
if err != nil {
- t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err)
}
neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
for _, n := range neighbors {
if existing, ok := neighborByAddr[n.Addr]; ok {
if diff := cmp.Diff(existing, n); diff != "" {
- t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
+ t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff)
}
- t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing)
+ t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing)
}
neighborByAddr[n.Addr] = n
}
@@ -907,18 +907,18 @@ func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *tes
Data: hdr.View().ToVectorisedView(),
}))
- neighbors, err := s.Neighbors(nicID)
+ neighbors, err := s.Neighbors(nicID, ProtocolNumber)
if err != nil {
- t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err)
}
neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
for _, n := range neighbors {
if existing, ok := neighborByAddr[n.Addr]; ok {
if diff := cmp.Diff(existing, n); diff != "" {
- t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
+ t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff)
}
- t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing)
+ t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing)
}
neighborByAddr[n.Addr] = n
}
@@ -1277,8 +1277,8 @@ func TestNeighborAdvertisementValidation(t *testing.T) {
// There is no need to create an entry if none exists, since the
// recipient has apparently not initiated any communication with the
// target.
- if neighbors, err := s.Neighbors(nicID); err != nil {
- t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ if neighbors, err := s.Neighbors(nicID, ProtocolNumber); err != nil {
+ t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err)
} else if len(neighbors) != 0 {
t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 704812641..c24f56ece 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -400,7 +400,10 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC
if !ok {
t.Fatal("NIC 2 does not exist")
}
- proto.neighborTable = nic.neighborTable
+
+ if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok {
+ proto.neighborTable = l.neighborTable
+ }
// Route all packets to NIC 2.
{
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 4504db752..5b6b58b1d 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -32,6 +32,8 @@ const linkAddrCacheSize = 512 // max cache entries
type linkAddrCache struct {
nic *NIC
+ linkRes LinkAddressResolver
+
// ageLimit is how long a cache entry is valid for.
ageLimit time.Duration
@@ -196,10 +198,10 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry {
return entry
}
-// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
+// get reports any known link address for addr.
+func (c *linkAddrCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
c.mu.Lock()
- entry := c.getOrCreateEntryLocked(k)
+ entry := c.getOrCreateEntryLocked(addr)
entry.mu.Lock()
defer entry.mu.Unlock()
c.mu.Unlock()
@@ -222,7 +224,7 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA
}
if entry.mu.done == nil {
entry.mu.done = make(chan struct{})
- go c.startAddressResolution(k, linkRes, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ go c.startAddressResolution(addr, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
}
return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{}
default:
@@ -230,11 +232,11 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA
}
}
-func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, done <-chan struct{}) {
+func (c *linkAddrCache) startAddressResolution(k tcpip.Address, localAddr tcpip.Address, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */)
+ c.linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */)
select {
case now := <-time.After(c.resolutionTimeout):
@@ -278,15 +280,18 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt
return true
}
-func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
- c := &linkAddrCache{
+func (c *linkAddrCache) init(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int, linkRes LinkAddressResolver) {
+ *c = linkAddrCache{
nic: nic,
+ linkRes: linkRes,
ageLimit: ageLimit,
resolutionTimeout: resolutionTimeout,
resolutionAttempts: resolutionAttempts,
}
+
+ c.mu.Lock()
c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize)
- return c
+ c.mu.Unlock()
}
var _ neighborTable = (*linkAddrCache)(nil)
@@ -307,7 +312,7 @@ func (*linkAddrCache) removeAll() tcpip.Error {
return &tcpip.ErrNotSupported{}
}
-func (c *linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress, _ LinkAddressResolver) {
+func (c *linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress) {
if len(linkAddr) != 0 {
// NUD allows probes without a link address but linkAddrCache
// is a simple neighbor table which does not implement NUD.
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 4df6f9265..9e7f331c9 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -77,10 +77,10 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe
return 1
}
-func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, tcpip.Error) {
+func getBlocking(c *linkAddrCache, addr tcpip.Address) (tcpip.LinkAddress, tcpip.Error) {
var attemptedResolution bool
for {
- got, ch, err := c.get(addr, linkRes, "", nil)
+ got, ch, err := c.get(addr, "", nil)
if _, ok := err.(*tcpip.ErrWouldBlock); ok {
if attemptedResolution {
return got, &tcpip.ErrTimeout{}
@@ -100,27 +100,28 @@ func newEmptyNIC() *NIC {
}
func TestCacheOverflow(t *testing.T) {
- c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
+ var c linkAddrCache
+ c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil)
for i := len(testAddrs) - 1; i >= 0; i-- {
e := testAddrs[i]
c.add(e.addr, e.linkAddr)
- got, _, err := c.get(e.addr, nil, "", nil)
+ got, _, err := c.get(e.addr, "", nil)
if err != nil {
- t.Errorf("insert %d, c.get(%s, nil, '', nil): %s", i, e.addr, err)
+ t.Errorf("insert %d, c.get(%s, '', nil): %s", i, e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("insert %d, got c.get(%s, nil, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
+ t.Errorf("insert %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
}
}
// Expect to find at least half of the most recent entries.
for i := 0; i < linkAddrCacheSize/2; i++ {
e := testAddrs[i]
- got, _, err := c.get(e.addr, nil, "", nil)
+ got, _, err := c.get(e.addr, "", nil)
if err != nil {
- t.Errorf("check %d, c.get(%s, nil, '', nil): %s", i, e.addr, err)
+ t.Errorf("check %d, c.get(%s, '', nil): %s", i, e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("check %d, got c.get(%s, nil, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
+ t.Errorf("check %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
}
}
// The earliest entries should no longer be in the cache.
@@ -135,8 +136,9 @@ func TestCacheOverflow(t *testing.T) {
}
func TestCacheConcurrent(t *testing.T) {
- c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
- linkRes := &testLinkAddressResolver{cache: c}
+ var c linkAddrCache
+ linkRes := &testLinkAddressResolver{cache: &c}
+ c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, linkRes)
var wg sync.WaitGroup
for r := 0; r < 16; r++ {
@@ -154,12 +156,12 @@ func TestCacheConcurrent(t *testing.T) {
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
e := testAddrs[len(testAddrs)-1]
- got, _, err := c.get(e.addr, linkRes, "", nil)
+ got, _, err := c.get(e.addr, "", nil)
if err != nil {
- t.Errorf("c.get(%s, _, '', nil): %s", e.addr, err)
+ t.Errorf("c.get(%s, '', nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("got c.get(%s, _, '', nil) = %s, want = %s", e.addr, got, e.linkAddr)
+ t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
e = testAddrs[0]
@@ -171,38 +173,40 @@ func TestCacheConcurrent(t *testing.T) {
}
func TestCacheAgeLimit(t *testing.T) {
- c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3)
- linkRes := &testLinkAddressResolver{cache: c}
+ var c linkAddrCache
+ linkRes := &testLinkAddressResolver{cache: &c}
+ c.init(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3, linkRes)
e := testAddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
- _, _, err := c.get(e.addr, linkRes, "", nil)
+ _, _, err := c.get(e.addr, "", nil)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got c.get(%s, _, '', nil) = %s, want = ErrWouldBlock", e.addr, err)
+ t.Errorf("got c.get(%s, '', nil) = %s, want = ErrWouldBlock", e.addr, err)
}
}
func TestCacheReplace(t *testing.T) {
- c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
+ var c linkAddrCache
+ c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil)
e := testAddrs[0]
l2 := e.linkAddr + "2"
c.add(e.addr, e.linkAddr)
- got, _, err := c.get(e.addr, nil, "", nil)
+ got, _, err := c.get(e.addr, "", nil)
if err != nil {
- t.Errorf("c.get(%s, nil, '', nil): %s", e.addr, err)
+ t.Errorf("c.get(%s, '', nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("got c.get(%s, nil, '', nil) = %s, want = %s", e.addr, got, e.linkAddr)
+ t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
c.add(e.addr, l2)
- got, _, err = c.get(e.addr, nil, "", nil)
+ got, _, err = c.get(e.addr, "", nil)
if err != nil {
- t.Errorf("c.get(%s, nil, '', nil): %s", e.addr, err)
+ t.Errorf("c.get(%s, '', nil): %s", e.addr, err)
}
if got != l2 {
- t.Errorf("got c.get(%s, nil, '', nil) = %s, want = %s", e.addr, got, l2)
+ t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, l2)
}
}
@@ -213,34 +217,36 @@ func TestCacheResolution(t *testing.T) {
//
// Using a large resolution timeout decreases the probability of experiencing
// this race condition and does not affect how long this test takes to run.
- c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1)
- linkRes := &testLinkAddressResolver{cache: c}
+ var c linkAddrCache
+ linkRes := &testLinkAddressResolver{cache: &c}
+ c.init(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1, linkRes)
for i, ta := range testAddrs {
- got, err := getBlocking(c, ta.addr, linkRes)
+ got, err := getBlocking(&c, ta.addr)
if err != nil {
- t.Errorf("check %d, getBlocking(_, %s, _): %s", i, ta.addr, err)
+ t.Errorf("check %d, getBlocking(_, %s): %s", i, ta.addr, err)
}
if got != ta.linkAddr {
- t.Errorf("check %d, got getBlocking(_, %s, _) = %s, want = %s", i, ta.addr, got, ta.linkAddr)
+ t.Errorf("check %d, got getBlocking(_, %s) = %s, want = %s", i, ta.addr, got, ta.linkAddr)
}
}
// Check that after resolved, address stays in the cache and never returns WouldBlock.
for i := 0; i < 10; i++ {
e := testAddrs[len(testAddrs)-1]
- got, _, err := c.get(e.addr, linkRes, "", nil)
+ got, _, err := c.get(e.addr, "", nil)
if err != nil {
- t.Errorf("c.get(%s, _, '', nil): %s", e.addr, err)
+ t.Errorf("c.get(%s, '', nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("got c.get(%s, _, '', nil) = %s, want = %s", e.addr, got, e.linkAddr)
+ t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
}
}
func TestCacheResolutionFailed(t *testing.T) {
- c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5)
- linkRes := &testLinkAddressResolver{cache: c}
+ var c linkAddrCache
+ linkRes := &testLinkAddressResolver{cache: &c}
+ c.init(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5, linkRes)
var requestCount uint32
linkRes.onLinkAddressRequest = func() {
@@ -249,20 +255,20 @@ func TestCacheResolutionFailed(t *testing.T) {
// First, sanity check that resolution is working...
e := testAddrs[0]
- got, err := getBlocking(c, e.addr, linkRes)
+ got, err := getBlocking(&c, e.addr)
if err != nil {
- t.Errorf("getBlocking(_, %s, _): %s", e.addr, err)
+ t.Errorf("getBlocking(_, %s): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("got getBlocking(_, %s, _) = %s, want = %s", e.addr, got, e.linkAddr)
+ t.Errorf("got getBlocking(_, %s) = %s, want = %s", e.addr, got, e.linkAddr)
}
before := atomic.LoadUint32(&requestCount)
e.addr += "2"
- a, err := getBlocking(c, e.addr, linkRes)
+ a, err := getBlocking(&c, e.addr)
if _, ok := err.(*tcpip.ErrTimeout); !ok {
- t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{})
+ t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{})
}
if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
@@ -273,12 +279,13 @@ func TestCacheResolutionFailed(t *testing.T) {
func TestCacheResolutionTimeout(t *testing.T) {
resolverDelay := 500 * time.Millisecond
expiration := resolverDelay / 10
- c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3)
- linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
+ var c linkAddrCache
+ linkRes := &testLinkAddressResolver{cache: &c, delay: resolverDelay}
+ c.init(newEmptyNIC(), expiration, 1*time.Millisecond, 3, linkRes)
e := testAddrs[0]
- a, err := getBlocking(c, e.addr, linkRes)
+ a, err := getBlocking(&c, e.addr)
if _, ok := err.(*tcpip.ErrTimeout); !ok {
- t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{})
+ t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{})
}
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index c13be137e..0238605af 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -2796,8 +2796,8 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN
NIC: nicID,
}})
- if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil {
- t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err)
+ if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil {
+ t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err)
}
return ndpDisp, e, s
}
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 64b8046f5..7e3132058 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -43,8 +43,9 @@ type NeighborStats struct {
// Their state is always Static. The amount of static entries stored in the
// cache is unbounded.
type neighborCache struct {
- nic *NIC
- state *NUDState
+ nic *NIC
+ state *NUDState
+ linkRes LinkAddressResolver
// mu protects the fields below.
mu sync.RWMutex
@@ -69,7 +70,7 @@ type neighborCache struct {
// reset to state incomplete, and returned. If no matching entry exists and the
// cache is not full, a new entry with state incomplete is allocated and
// returned.
-func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry {
+func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntry {
n.mu.Lock()
defer n.mu.Unlock()
@@ -85,7 +86,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
// The entry that needs to be created must be dynamic since all static
// entries are directly added to the cache via addStaticEntry.
- entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes)
+ entry := newNeighborEntry(n, remoteAddr, n.state)
if n.dynamic.count == neighborCacheSize {
e := n.dynamic.lru.Back()
e.mu.Lock()
@@ -122,8 +123,8 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
// packet prompting NUD/link address resolution.
//
// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry.
-func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) {
- entry := n.getOrCreateEntry(remoteAddr, linkRes)
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) {
+ entry := n.getOrCreateEntry(remoteAddr)
entry.mu.Lock()
defer entry.mu.Unlock()
@@ -202,7 +203,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
entry.mu.Unlock()
}
- n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
+ n.cache[addr] = newStaticNeighborEntry(n, addr, linkAddr, n.state)
}
// removeEntry removes a dynamic or static entry by address from the neighbor
@@ -265,8 +266,8 @@ func (n *neighborCache) neighbors() ([]NeighborEntry, tcpip.Error) {
return n.entries(), nil
}
-func (n *neighborCache) get(addr tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
- entry, ch, err := n.entry(addr, localAddr, linkRes, onResolve)
+func (n *neighborCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
+ entry, ch, err := n.entry(addr, localAddr, onResolve)
return entry.LinkAddr, ch, err
}
@@ -286,8 +287,8 @@ func (n *neighborCache) removeAll() tcpip.Error {
// handleProbe handles a neighbor probe as defined by RFC 4861 section 7.2.3.
//
// Validation of the probe is expected to be handled by the caller.
-func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) {
- entry := n.getOrCreateEntry(remoteAddr, linkRes)
+func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ entry := n.getOrCreateEntry(remoteAddr)
entry.mu.Lock()
entry.handleProbeLocked(remoteLinkAddr)
entry.mu.Unlock()
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index 122888fcf..b489b5e08 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -76,10 +76,15 @@ func entryDiffOptsWithSort() []cmp.Option {
}))
}
-func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache {
+func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *testNeighborResolver {
config.resetInvalidFields()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
- neigh := &neighborCache{
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ entries: newTestEntryStore(),
+ delay: typicalLatency,
+ }
+ linkRes.neigh = &neighborCache{
nic: &NIC{
stack: &Stack{
clock: clock,
@@ -88,10 +93,11 @@ func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock
id: 1,
stats: makeNICStats(),
},
- state: NewNUDState(config, rng),
- cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ state: NewNUDState(config, rng),
+ linkRes: linkRes,
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
}
- return neigh
+ return linkRes
}
// testEntryStore contains a set of IP to NeighborEntry mappings.
@@ -241,10 +247,10 @@ func TestNeighborCacheGetConfig(t *testing.T) {
nudDisp := testNUDDispatcher{}
c := DefaultNUDConfigurations()
clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, c, clock)
+ linkRes := newTestNeighborResolver(&nudDisp, c, clock)
- if got, want := neigh.config(), c; got != want {
- t.Errorf("got neigh.config() = %+v, want = %+v", got, want)
+ if got, want := linkRes.neigh.config(), c; got != want {
+ t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want)
}
// No events should have been dispatched.
@@ -259,14 +265,14 @@ func TestNeighborCacheSetConfig(t *testing.T) {
nudDisp := testNUDDispatcher{}
c := DefaultNUDConfigurations()
clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, c, clock)
+ linkRes := newTestNeighborResolver(&nudDisp, c, clock)
c.MinRandomFactor = 1
c.MaxRandomFactor = 1
- neigh.setConfig(c)
+ linkRes.neigh.setConfig(c)
- if got, want := neigh.config(), c; got != want {
- t.Errorf("got neigh.config() = %+v, want = %+v", got, want)
+ if got, want := linkRes.neigh.config(), c; got != want {
+ t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want)
}
// No events should have been dispatched.
@@ -281,22 +287,15 @@ func TestNeighborCacheEntry(t *testing.T) {
c := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, c, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
+ linkRes := newTestNeighborResolver(&nudDisp, c, clock)
- entry, ok := store.entry(0)
+ entry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("store.entry(0) not found")
+ t.Fatal("linkRes.entries.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ _, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
@@ -328,8 +327,8 @@ func TestNeighborCacheEntry(t *testing.T) {
t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
+ if _, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err != nil {
+ t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
}
// No more events should have been dispatched.
@@ -345,23 +344,16 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
nudDisp := testNUDDispatcher{}
clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
+ linkRes := newTestNeighborResolver(&nudDisp, config, clock)
- entry, ok := store.entry(0)
+ entry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("store.entry(0) not found")
+ t.Fatal("linkRes.entries.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ _, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
@@ -393,7 +385,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
- neigh.removeEntry(entry.Addr)
+ linkRes.neigh.removeEntry(entry.Addr)
{
wantEvents := []testEntryEventInfo{
@@ -416,17 +408,15 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
}
{
- _, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ _, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
}
}
type testContext struct {
clock *faketime.ManualClock
- neigh *neighborCache
- store *testEntryStore
linkRes *testNeighborResolver
nudDisp *testNUDDispatcher
}
@@ -434,19 +424,10 @@ type testContext struct {
func newTestContext(c NUDConfigurations) testContext {
nudDisp := &testNUDDispatcher{}
clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(nudDisp, c, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
+ linkRes := newTestNeighborResolver(nudDisp, c, clock)
return testContext{
clock: clock,
- neigh: neigh,
- store: store,
linkRes: linkRes,
nudDisp: nudDisp,
}
@@ -460,17 +441,17 @@ type overflowOptions struct {
func (c *testContext) overflowCache(opts overflowOptions) error {
// Fill the neighbor cache to capacity to verify the LRU eviction strategy is
// working properly after the entry removal.
- for i := opts.startAtEntryIndex; i < c.store.size(); i++ {
+ for i := opts.startAtEntryIndex; i < c.linkRes.entries.size(); i++ {
// Add a new entry
- entry, ok := c.store.entry(i)
+ entry, ok := c.linkRes.entries.entry(i)
if !ok {
- return fmt.Errorf("c.store.entry(%d) not found", i)
+ return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i)
}
- _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ return fmt.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
- c.clock.Advance(c.neigh.config().RetransmitTimer)
+ c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer)
var wantEvents []testEntryEventInfo
@@ -478,9 +459,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
// LRU eviction strategy. Note that the number of static entries should not
// affect the total number of dynamic entries that can be added.
if i >= neighborCacheSize+opts.startAtEntryIndex {
- removedEntry, ok := c.store.entry(i - neighborCacheSize)
+ removedEntry, ok := c.linkRes.entries.entry(i - neighborCacheSize)
if !ok {
- return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize)
+ return fmt.Errorf("linkRes.entries.entry(%d) not found", i-neighborCacheSize)
}
wantEvents = append(wantEvents, testEntryEventInfo{
EventType: entryTestRemoved,
@@ -523,10 +504,10 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
// by entries() is nondeterministic, so entries have to be sorted before
// comparison.
wantUnsortedEntries := opts.wantStaticEntries
- for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ {
- entry, ok := c.store.entry(i)
+ for i := c.linkRes.entries.size() - neighborCacheSize; i < c.linkRes.entries.size(); i++ {
+ entry, ok := c.linkRes.entries.entry(i)
if !ok {
- return fmt.Errorf("c.store.entry(%d) not found", i)
+ return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
Addr: entry.Addr,
@@ -536,7 +517,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
- if diff := cmp.Diff(wantUnsortedEntries, c.neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
+ if diff := cmp.Diff(wantUnsortedEntries, c.linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
}
@@ -580,15 +561,15 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
c := newTestContext(config)
// Add a dynamic entry
- entry, ok := c.store.entry(0)
+ entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.store.entry(0) not found")
+ t.Fatal("c.linkRes.entries.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
- c.clock.Advance(c.neigh.config().RetransmitTimer)
+ c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -617,7 +598,7 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
}
// Remove the entry
- c.neigh.removeEntry(entry.Addr)
+ c.linkRes.neigh.removeEntry(entry.Addr)
{
wantEvents := []testEntryEventInfo{
@@ -656,12 +637,12 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
c := newTestContext(config)
// Add a static entry
- entry, ok := c.store.entry(0)
+ entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.store.entry(0) not found")
+ t.Fatal("c.linkRes.entries.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
- c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -682,7 +663,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
}
// Remove the static entry that was just added
- c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
// No more events should have been dispatched.
c.nudDisp.mu.Lock()
@@ -700,12 +681,12 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
c := newTestContext(config)
// Add a static entry
- entry, ok := c.store.entry(0)
+ entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.store.entry(0) not found")
+ t.Fatal("c.linkRes.entries.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
- c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -727,7 +708,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
// Add a duplicate entry with a different link address
staticLinkAddr += "duplicate"
- c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
{
wantEvents := []testEntryEventInfo{
{
@@ -762,12 +743,12 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
c := newTestContext(config)
// Add a static entry
- entry, ok := c.store.entry(0)
+ entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.store.entry(0) not found")
+ t.Fatal("c.linkRes.entries.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
- c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -788,7 +769,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
}
// Remove the static entry that was just added
- c.neigh.removeEntry(entry.Addr)
+ c.linkRes.neigh.removeEntry(entry.Addr)
{
wantEvents := []testEntryEventInfo{
{
@@ -832,13 +813,13 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
c := newTestContext(config)
// Add a dynamic entry
- entry, ok := c.store.entry(0)
+ entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.store.entry(0) not found")
+ t.Fatal("c.linkRes.entries.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -870,7 +851,7 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
// Override the entry with a static one using the same address
staticLinkAddr := entry.LinkAddr + "static"
- c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
{
wantEvents := []testEntryEventInfo{
{
@@ -925,14 +906,14 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
c := newTestContext(config)
- entry, ok := c.store.entry(0)
+ entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.store.entry(0) not found")
+ t.Fatal("c.linkRes.entries.entry(0) not found")
}
- c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr)
- e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ c.linkRes.neigh.addStaticEntry(entry.Addr, entry.LinkAddr)
+ e, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil)
if err != nil {
- t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err)
+ t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -940,7 +921,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
State: Static,
}
if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
- t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
+ t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
wantEvents := []testEntryEventInfo{
@@ -982,23 +963,16 @@ func TestNeighborCacheClear(t *testing.T) {
nudDisp := testNUDDispatcher{}
clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
+ linkRes := newTestNeighborResolver(&nudDisp, config, clock)
// Add a dynamic entry.
- entry, ok := store.entry(0)
+ entry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("store.entry(0) not found")
+ t.Fatal("linkRes.entries.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ _, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
@@ -1030,7 +1004,7 @@ func TestNeighborCacheClear(t *testing.T) {
}
// Add a static entry.
- neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1)
+ linkRes.neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1)
{
wantEvents := []testEntryEventInfo{
@@ -1054,7 +1028,7 @@ func TestNeighborCacheClear(t *testing.T) {
}
// Clear should remove both dynamic and static entries.
- neigh.clear()
+ linkRes.neigh.clear()
// Remove events dispatched from clear() have no deterministic order so they
// need to be sorted beforehand.
@@ -1098,13 +1072,13 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
c := newTestContext(config)
// Add a dynamic entry
- entry, ok := c.store.entry(0)
+ entry, ok := c.linkRes.entries.entry(0)
if !ok {
- t.Fatal("c.store.entry(0) not found")
+ t.Fatal("c.linkRes.entries.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -1135,7 +1109,7 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
}
// Clear the cache.
- c.neigh.clear()
+ c.linkRes.neigh.clear()
{
wantEvents := []testEntryEventInfo{
{
@@ -1174,18 +1148,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
nudDisp := testNUDDispatcher{}
clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
+ linkRes := newTestNeighborResolver(&nudDisp, config, clock)
- frequentlyUsedEntry, ok := store.entry(0)
+ frequentlyUsedEntry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("store.entry(0) not found")
+ t.Fatal("linkRes.entries.entry(0) not found")
}
// The following logic is very similar to overflowCache, but
@@ -1193,23 +1160,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
// Fill the neighbor cache to capacity
for i := 0; i < neighborCacheSize; i++ {
- entry, ok := store.entry(i)
+ entry, ok := linkRes.entries.entry(i)
if !ok {
- t.Fatalf("store.entry(%d) not found", i)
+ t.Fatalf("linkRes.entries.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
select {
case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
wantEvents := []testEntryEventInfo{
{
@@ -1240,38 +1207,38 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
}
// Keep adding more entries
- for i := neighborCacheSize; i < store.size(); i++ {
+ for i := neighborCacheSize; i < linkRes.entries.size(); i++ {
// Periodically refresh the frequently used entry
if i%(neighborCacheSize/2) == 0 {
- if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err)
+ if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil {
+ t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", frequentlyUsedEntry.Addr, err)
}
}
- entry, ok := store.entry(i)
+ entry, ok := linkRes.entries.entry(i)
if !ok {
- t.Fatalf("store.entry(%d) not found", i)
+ t.Fatalf("linkRes.entries.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
select {
case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
// An entry should have been removed, as per the LRU eviction strategy
- removedEntry, ok := store.entry(i - neighborCacheSize + 1)
+ removedEntry, ok := linkRes.entries.entry(i - neighborCacheSize + 1)
if !ok {
- t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1)
+ t.Fatalf("linkRes.entries.entry(%d) not found", i-neighborCacheSize+1)
}
wantEvents := []testEntryEventInfo{
{
@@ -1321,10 +1288,10 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
},
}
- for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ {
- entry, ok := store.entry(i)
+ for i := linkRes.entries.size() - neighborCacheSize + 1; i < linkRes.entries.size(); i++ {
+ entry, ok := linkRes.entries.entry(i)
if !ok {
- t.Fatalf("store.entry(%d) not found", i)
+ t.Fatalf("linkRes.entries.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
Addr: entry.Addr,
@@ -1334,7 +1301,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
- if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
+ if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
}
@@ -1353,26 +1320,19 @@ func TestNeighborCacheConcurrent(t *testing.T) {
nudDisp := testNUDDispatcher{}
clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
+ linkRes := newTestNeighborResolver(&nudDisp, config, clock)
- storeEntries := store.entries()
+ storeEntries := linkRes.entries.entries()
for _, entry := range storeEntries {
var wg sync.WaitGroup
for r := 0; r < concurrentProcesses; r++ {
wg.Add(1)
go func(entry NeighborEntry) {
defer wg.Done()
- switch e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err.(type) {
+ switch e, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err.(type) {
case nil, *tcpip.ErrWouldBlock:
default:
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{})
+ t.Errorf("got linkRes.neigh.entry(%s, '', nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{})
}
}(entry)
}
@@ -1390,10 +1350,10 @@ func TestNeighborCacheConcurrent(t *testing.T) {
// The order of entries reported by entries() is nondeterministic, so entries
// have to be sorted before comparison.
var wantUnsortedEntries []NeighborEntry
- for i := store.size() - neighborCacheSize; i < store.size(); i++ {
- entry, ok := store.entry(i)
+ for i := linkRes.entries.size() - neighborCacheSize; i < linkRes.entries.size(); i++ {
+ entry, ok := linkRes.entries.entry(i)
if !ok {
- t.Errorf("store.entry(%d) not found", i)
+ t.Errorf("linkRes.entries.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
Addr: entry.Addr,
@@ -1403,7 +1363,7 @@ func TestNeighborCacheConcurrent(t *testing.T) {
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
- if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
+ if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
}
}
@@ -1413,41 +1373,34 @@ func TestNeighborCacheReplace(t *testing.T) {
nudDisp := testNUDDispatcher{}
clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
+ linkRes := newTestNeighborResolver(&nudDisp, config, clock)
// Add an entry
- entry, ok := store.entry(0)
+ entry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("store.entry(0) not found")
+ t.Fatal("linkRes.entries.entry(0) not found")
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
select {
case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
// Verify the entry exists
{
- e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ e, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err)
+ t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err)
}
if t.Failed() {
t.FailNow()
@@ -1458,21 +1411,21 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
+ t.Errorf("linkRes.neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
}
// Notify of a link address change
var updatedLinkAddr tcpip.LinkAddress
{
- entry, ok := store.entry(1)
+ entry, ok := linkRes.entries.entry(1)
if !ok {
- t.Fatal("store.entry(1) not found")
+ t.Fatal("linkRes.entries.entry(1) not found")
}
updatedLinkAddr = entry.LinkAddr
}
- store.set(0, updatedLinkAddr)
- neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{
+ linkRes.entries.set(0, updatedLinkAddr)
+ linkRes.neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
@@ -1482,9 +1435,9 @@ func TestNeighborCacheReplace(t *testing.T) {
//
// Verify the entry's new link address and the new state.
{
- e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ e, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
if err != nil {
- t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
+ t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1492,17 +1445,17 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Delay,
}
if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
+ t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
// Verify that the neighbor is now reachable.
{
- e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ e, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
clock.Advance(typicalLatency)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
+ t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1510,7 +1463,7 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
+ t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
}
}
@@ -1520,46 +1473,39 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
nudDisp := testNUDDispatcher{}
clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
+ linkRes := newTestNeighborResolver(&nudDisp, config, clock)
var requestCount uint32
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- onLinkAddressRequest: func() {
- atomic.AddUint32(&requestCount, 1)
- },
+ linkRes.onLinkAddressRequest = func() {
+ atomic.AddUint32(&requestCount, 1)
}
- entry, ok := store.entry(0)
+ entry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("store.entry(0) not found")
+ t.Fatal("linkRes.entries.entry(0) not found")
}
// First, sanity check that resolution is working
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
select {
case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
}
- got, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ got, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
+ t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1567,7 +1513,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
+ t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
// Verify address resolution fails for an unknown address.
@@ -1575,24 +1521,24 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
entry.Addr += "2"
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
select {
case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
}
- maxAttempts := neigh.config().MaxUnicastProbes
+ maxAttempts := linkRes.neigh.config().MaxUnicastProbes
if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want {
t.Errorf("got link address request count = %d, want = %d", got, want)
}
@@ -1606,27 +1552,22 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
config.RetransmitTimer = time.Millisecond // small enough to cause timeout
clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(nil, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: time.Minute, // large enough to cause timeout
- }
+ linkRes := newTestNeighborResolver(nil, config, clock)
+ // large enough to cause timeout
+ linkRes.delay = time.Minute
- entry, ok := store.entry(0)
+ entry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("store.entry(0) not found")
+ t.Fatal("linkRes.entries.entry(0) not found")
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
@@ -1634,7 +1575,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
select {
case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
}
@@ -1643,31 +1584,24 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
func TestNeighborCacheRetryResolution(t *testing.T) {
config := DefaultNUDConfigurations()
clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(nil, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- // Simulate a faulty link.
- dropReplies: true,
- }
+ linkRes := newTestNeighborResolver(nil, config, clock)
+ // Simulate a faulty link.
+ linkRes.dropReplies = true
- entry, ok := store.entry(0)
+ entry, ok := linkRes.entries.entry(0)
if !ok {
- t.Fatal("store.entry(0) not found")
+ t.Fatal("linkRes.entries.entry(0) not found")
}
// Perform address resolution with a faulty link, which will fail.
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
@@ -1675,7 +1609,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
select {
case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
}
@@ -1687,20 +1621,20 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
State: Failed,
},
}
- if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" {
+ if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" {
t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff)
}
// Retry address resolution with a working link.
linkRes.dropReplies = false
{
- incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ incompleteEntry, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
if incompleteEntry.State != Incomplete {
t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete)
@@ -1712,9 +1646,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
if !ok {
t.Fatal("expected successful address resolution")
}
- reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ reachableEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
if err != nil {
- t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err)
+ t.Fatalf("linkRes.neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err)
}
if reachableEntry.Addr != entry.Addr {
t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr)
@@ -1726,7 +1660,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String())
}
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
}
}
@@ -1735,42 +1669,36 @@ func BenchmarkCacheClear(b *testing.B) {
b.StopTimer()
config := DefaultNUDConfigurations()
clock := &tcpip.StdClock{}
- neigh := newTestNeighborCache(nil, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: 0,
- }
+ linkRes := newTestNeighborResolver(nil, config, clock)
+ linkRes.delay = 0
// Clear for every possible size of the cache
for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ {
// Fill the neighbor cache to capacity.
for i := 0; i < cacheSize; i++ {
- entry, ok := store.entry(i)
+ entry, ok := linkRes.entries.entry(i)
if !ok {
- b.Fatalf("store.entry(%d) not found", i)
+ b.Fatalf("linkRes.entries.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ b.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
select {
case <-ch:
default:
- b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
}
b.StartTimer()
- neigh.clear()
+ linkRes.neigh.clear()
b.StopTimer()
}
}
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index a037ca6f9..b05f96d4f 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -77,11 +77,7 @@ const (
type neighborEntry struct {
neighborEntryEntry
- nic *NIC
-
- // linkRes provides the functionality to send reachability probes, used in
- // Neighbor Unreachability Detection.
- linkRes LinkAddressResolver
+ cache *neighborCache
// nudState points to the Neighbor Unreachability Detection configuration.
nudState *NUDState
@@ -106,10 +102,9 @@ type neighborEntry struct {
// state, Unknown. Transition out of Unknown by calling either
// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created
// neighborEntry.
-func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry {
+func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState *NUDState) *neighborEntry {
return &neighborEntry{
- nic: nic,
- linkRes: linkRes,
+ cache: cache,
nudState: nudState,
neigh: NeighborEntry{
Addr: remoteAddr,
@@ -121,18 +116,18 @@ func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, li
// newStaticNeighborEntry creates a neighbor cache entry starting at the
// Static state. The entry can only transition out of Static by directly
// calling `setStateLocked`.
-func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
+func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
entry := NeighborEntry{
Addr: addr,
LinkAddr: linkAddr,
State: Static,
- UpdatedAtNanos: nic.stack.clock.NowNanoseconds(),
+ UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(),
}
- if nic.stack.nudDisp != nil {
- nic.stack.nudDisp.OnNeighborAdded(nic.id, entry)
+ if nudDisp := cache.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborAdded(cache.nic.id, entry)
}
return &neighborEntry{
- nic: nic,
+ cache: cache,
nudState: state,
neigh: entry,
}
@@ -158,7 +153,7 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
// is resolved (which ends up obtaining the entry's lock) while holding the
// link resolution queue's lock. Dequeuing packets in a new goroutine avoids
// a lock ordering violation.
- go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded)
+ go e.cache.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded)
}
}
@@ -167,8 +162,8 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchAddEventLocked() {
- if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborAdded(e.nic.id, e.neigh)
+ if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborAdded(e.cache.nic.id, e.neigh)
}
}
@@ -177,8 +172,8 @@ func (e *neighborEntry) dispatchAddEventLocked() {
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchChangeEventLocked() {
- if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborChanged(e.nic.id, e.neigh)
+ if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborChanged(e.cache.nic.id, e.neigh)
}
}
@@ -187,8 +182,8 @@ func (e *neighborEntry) dispatchChangeEventLocked() {
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchRemoveEventLocked() {
- if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborRemoved(e.nic.id, e.neigh)
+ if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborRemoved(e.cache.nic.id, e.neigh)
}
}
@@ -206,7 +201,7 @@ func (e *neighborEntry) cancelJobLocked() {
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) removeLocked() {
- e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
+ e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
e.dispatchRemoveEventLocked()
e.cancelJobLocked()
e.notifyCompletionLocked(false /* succeeded */)
@@ -222,7 +217,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
prev := e.neigh.State
e.neigh.State = next
- e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
+ e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
config := e.nudState.Config()
switch next {
@@ -230,14 +225,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev))
case Reachable:
- e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.job = e.cache.nic.stack.newJob(&e.mu, func() {
e.setStateLocked(Stale)
e.dispatchChangeEventLocked()
})
e.job.Schedule(e.nudState.ReachableTime())
case Delay:
- e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.job = e.cache.nic.stack.newJob(&e.mu, func() {
e.setStateLocked(Probe)
e.dispatchChangeEventLocked()
})
@@ -254,14 +249,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil {
+ if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil {
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
}
retryCounter++
- e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
+ e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe)
e.job.Schedule(config.RetransmitTimer)
}
@@ -269,7 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
// for finishing the state transition. This is necessary to avoid
// deadlock where sending and processing probes are done synchronously,
// such as loopback and integration tests.
- e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
+ e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe)
e.job.Schedule(immediateDuration)
case Failed:
@@ -292,12 +287,12 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
switch e.neigh.State {
case Failed:
- e.nic.stats.Neighbor.FailedEntryLookups.Increment()
+ e.cache.nic.stats.Neighbor.FailedEntryLookups.Increment()
fallthrough
case Unknown:
e.neigh.State = Incomplete
- e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
+ e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
e.dispatchAddEventLocked()
@@ -340,7 +335,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// address SHOULD be placed in the IP Source Address of the outgoing
// solicitation.
//
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil {
+ if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil {
// There is no need to log the error here; the NUD implementation may
// assume a working link. A valid link should be the responsibility of
// the NIC/stack.LinkEndpoint.
@@ -350,7 +345,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
}
retryCounter++
- e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe)
e.job.Schedule(config.RetransmitTimer)
}
@@ -358,7 +353,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// for finishing the state transition. This is necessary to avoid
// deadlock where sending and processing probes are done synchronously,
// such as loopback and integration tests.
- e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe)
e.job.Schedule(immediateDuration)
case Stale:
@@ -504,7 +499,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
//
// TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6
// here.
- ep, ok := e.nic.networkEndpoints[header.IPv6ProtocolNumber]
+ ep, ok := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber]
if !ok {
panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint"))
}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index 5e5e0e6ca..57cfbdb8b 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -230,23 +230,30 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
},
stats: makeNICStats(),
}
+ netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil)
nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
- header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil),
+ header.IPv6ProtocolNumber: netEP,
}
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
nudState := NewNUDState(c, rng)
- linkRes := entryTestLinkResolver{}
- entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes)
-
+ var linkRes entryTestLinkResolver
// Stub out the neighbor cache to verify deletion from the cache.
neigh := &neighborCache{
- nic: &nic,
- state: nudState,
- cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ nic: &nic,
+ state: nudState,
+ linkRes: &linkRes,
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
+ l := linkResolver{
+ resolver: &linkRes,
+ neighborTable: neigh,
}
+ entry := newNeighborEntry(neigh, entryTestAddr1 /* remoteAddr */, nudState)
neigh.cache[entryTestAddr1] = entry
- nic.neighborTable = neigh
+ nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]linkResolver{
+ header.IPv6ProtocolNumber: l,
+ }
return entry, &disp, &linkRes, clock
}
@@ -836,7 +843,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, clock := entryTestSetup(c)
- ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint)
+ ipv6EP := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index c813b0da5..693ea064a 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -27,11 +27,11 @@ import (
type neighborTable interface {
neighbors() ([]NeighborEntry, tcpip.Error)
addStaticEntry(tcpip.Address, tcpip.LinkAddress)
- get(addr tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error)
+ get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error)
remove(tcpip.Address) tcpip.Error
removeAll() tcpip.Error
- handleProbe(tcpip.Address, tcpip.LinkAddress, LinkAddressResolver)
+ handleProbe(tcpip.Address, tcpip.LinkAddress)
handleConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags)
handleUpperLevelConfirmation(tcpip.Address)
@@ -41,6 +41,20 @@ type neighborTable interface {
var _ NetworkInterface = (*NIC)(nil)
+type linkResolver struct {
+ resolver LinkAddressResolver
+
+ neighborTable neighborTable
+}
+
+func (l *linkResolver) getNeighborLinkAddress(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
+ return l.neighborTable.get(addr, localAddr, onResolve)
+}
+
+func (l *linkResolver) confirmReachable(addr tcpip.Address) {
+ l.neighborTable.handleUpperLevelConfirmation(addr)
+}
+
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
@@ -56,7 +70,7 @@ type NIC struct {
// The network endpoints themselves may be modified by calling the interface's
// methods, but the map reference and entries must be constant.
networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
- linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+ linkAddrResolvers map[tcpip.NetworkProtocolNumber]linkResolver
// enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
//
@@ -67,8 +81,6 @@ type NIC struct {
// complete.
linkResQueue packetsPendingLinkResolution
- neighborTable neighborTable
-
mu struct {
sync.RWMutex
spoofing bool
@@ -153,25 +165,13 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
context: ctx,
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
- linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]linkResolver),
}
nic.linkResQueue.init(nic)
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0
- if resolutionRequired {
- if stack.useNeighborCache {
- nic.neighborTable = &neighborCache{
- nic: nic,
- state: NewNUDState(stack.nudConfigs, stack.randomGenerator),
- cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
- }
- } else {
- nic.neighborTable = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts)
- }
- }
-
// Register supported packet and network endpoint protocols.
for _, netProto := range header.Ethertypes {
nic.mu.packetEPs[netProto] = new(packetEndpointList)
@@ -185,7 +185,24 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
if resolutionRequired {
if r, ok := netEP.(LinkAddressResolver); ok {
- nic.linkAddrResolvers[r.LinkAddressProtocol()] = r
+ l := linkResolver{
+ resolver: r,
+ }
+
+ if stack.useNeighborCache {
+ l.neighborTable = &neighborCache{
+ nic: nic,
+ state: NewNUDState(stack.nudConfigs, stack.randomGenerator),
+ linkRes: r,
+
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
+ } else {
+ cache := new(linkAddrCache)
+ cache.init(nic, ageLimit, resolutionTimeout, resolutionAttempts, r)
+ l.neighborTable = cache
+ }
+ nic.linkAddrResolvers[r.LinkAddressProtocol()] = l
}
}
}
@@ -240,18 +257,19 @@ func (n *NIC) disableLocked() {
for _, ep := range n.networkEndpoints {
ep.Disable()
- }
- // Clear the neighbour table (including static entries) as we cannot guarantee
- // that the current neighbour table will be valid when the NIC is enabled
- // again.
- //
- // This matches linux's behaviour at the time of writing:
- // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371
- switch err := n.clearNeighbors(); err.(type) {
- case nil, *tcpip.ErrNotSupported:
- default:
- panic(fmt.Sprintf("n.clearNeighbors(): %s", err))
+ // Clear the neighbour table (including static entries) as we cannot
+ // guarantee that the current neighbour table will be valid when the NIC is
+ // enabled again.
+ //
+ // This matches linux's behaviour at the time of writing:
+ // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371
+ netProto := ep.NetworkProtocolNumber()
+ switch err := n.clearNeighbors(netProto); err.(type) {
+ case nil, *tcpip.ErrNotSupported:
+ default:
+ panic(fmt.Sprintf("n.clearNeighbors(%d): %s", netProto, err))
+ }
}
if !n.setEnabled(false) {
@@ -604,63 +622,49 @@ func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error {
return &tcpip.ErrBadLocalAddress{}
}
-func (n *NIC) confirmReachable(addr tcpip.Address) {
- if n.neighborTable != nil {
- n.neighborTable.handleUpperLevelConfirmation(addr)
- }
-}
-
func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error {
linkRes, ok := n.linkAddrResolvers[protocol]
if !ok {
return &tcpip.ErrNotSupported{}
}
- if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok {
+ if linkAddr, ok := linkRes.resolver.ResolveStaticAddress(addr); ok {
onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true})
return nil
}
- _, _, err := n.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve)
+ _, _, err := linkRes.getNeighborLinkAddress(addr, localAddr, onResolve)
return err
}
-func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
- if n.neighborTable != nil {
- return n.neighborTable.get(addr, linkRes, localAddr, onResolve)
- }
-
- return "", nil, &tcpip.ErrNotSupported{}
-}
-
-func (n *NIC) neighbors() ([]NeighborEntry, tcpip.Error) {
- if n.neighborTable != nil {
- return n.neighborTable.neighbors()
+func (n *NIC) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) {
+ if linkRes, ok := n.linkAddrResolvers[protocol]; ok {
+ return linkRes.neighborTable.neighbors()
}
return nil, &tcpip.ErrNotSupported{}
}
-func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) tcpip.Error {
- if n.neighborTable != nil {
- n.neighborTable.addStaticEntry(addr, linkAddress)
+func (n *NIC) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtocolNumber, linkAddress tcpip.LinkAddress) tcpip.Error {
+ if linkRes, ok := n.linkAddrResolvers[protocol]; ok {
+ linkRes.neighborTable.addStaticEntry(addr, linkAddress)
return nil
}
return &tcpip.ErrNotSupported{}
}
-func (n *NIC) removeNeighbor(addr tcpip.Address) tcpip.Error {
- if n.neighborTable != nil {
- return n.neighborTable.remove(addr)
+func (n *NIC) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
+ if linkRes, ok := n.linkAddrResolvers[protocol]; ok {
+ return linkRes.neighborTable.remove(addr)
}
return &tcpip.ErrNotSupported{}
}
-func (n *NIC) clearNeighbors() tcpip.Error {
- if n.neighborTable != nil {
- return n.neighborTable.removeAll()
+func (n *NIC) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error {
+ if linkRes, ok := n.linkAddrResolvers[protocol]; ok {
+ return linkRes.neighborTable.removeAll()
}
return &tcpip.ErrNotSupported{}
@@ -947,9 +951,9 @@ func (n *NIC) Name() string {
}
// nudConfigs gets the NUD configurations for n.
-func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) {
- if n.neighborTable != nil {
- return n.neighborTable.nudConfig()
+func (n *NIC) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) {
+ if linkRes, ok := n.linkAddrResolvers[protocol]; ok {
+ return linkRes.neighborTable.nudConfig()
}
return NUDConfigurations{}, &tcpip.ErrNotSupported{}
@@ -959,10 +963,10 @@ func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) {
//
// Note, if c contains invalid NUD configuration values, it will be fixed to
// use default values for the erroneous values.
-func (n *NIC) setNUDConfigs(c NUDConfigurations) tcpip.Error {
- if n.neighborTable != nil {
+func (n *NIC) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error {
+ if linkRes, ok := n.linkAddrResolvers[protocol]; ok {
c.resetInvalidFields()
- return n.neighborTable.setNUDConfig(c)
+ return linkRes.neighborTable.setNUDConfig(c)
}
return &tcpip.ErrNotSupported{}
@@ -1003,15 +1007,21 @@ func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool {
}
// HandleNeighborProbe implements NetworkInterface.
-func (n *NIC) HandleNeighborProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) {
- if n.neighborTable != nil {
- n.neighborTable.handleProbe(addr, linkAddr, linkRes)
+func (n *NIC) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error {
+ if l, ok := n.linkAddrResolvers[protocol]; ok {
+ l.neighborTable.handleProbe(addr, linkAddr)
+ return nil
}
+
+ return &tcpip.ErrNotSupported{}
}
// HandleNeighborConfirmation implements NetworkInterface.
-func (n *NIC) HandleNeighborConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
- if n.neighborTable != nil {
- n.neighborTable.handleConfirmation(addr, linkAddr, flags)
+func (n *NIC) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) tcpip.Error {
+ if l, ok := n.linkAddrResolvers[protocol]; ok {
+ l.neighborTable.handleConfirmation(addr, linkAddr, flags)
+ return nil
}
+
+ return &tcpip.ErrNotSupported{}
}
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index 504acc246..e9acef6a2 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -19,7 +19,9 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -52,66 +54,146 @@ func (f *fakeRand) Float32() float32 {
return f.num
}
-// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if
-// we attempt to update NUD configurations using an invalid NICID.
-func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) {
- s := stack.New(stack.Options{
- // A neighbor cache is required to store NUDConfigurations. The networking
- // stack will only allocate neighbor caches if a protocol providing link
- // address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- UseNeighborCache: true,
- })
+func TestNUDFunctions(t *testing.T) {
+ const nicID = 1
- // No NIC with ID 1 yet.
- config := stack.NUDConfigurations{}
- err := s.SetNUDConfigurations(1, config)
- if _, ok := err.(*tcpip.ErrUnknownNICID); !ok {
- t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, &tcpip.ErrUnknownNICID{})
+ tests := []struct {
+ name string
+ nicID tcpip.NICID
+ netProtoFactory []stack.NetworkProtocolFactory
+ extraLinkCapabilities stack.LinkEndpointCapabilities
+ expectedErr tcpip.Error
+ }{
+ {
+ name: "Invalid NICID",
+ nicID: nicID + 1,
+ netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ extraLinkCapabilities: stack.CapabilityResolutionRequired,
+ expectedErr: &tcpip.ErrUnknownNICID{},
+ },
+ {
+ name: "No network protocol",
+ nicID: nicID,
+ expectedErr: &tcpip.ErrNotSupported{},
+ },
+ {
+ name: "With IPv6",
+ nicID: nicID,
+ netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ expectedErr: &tcpip.ErrNotSupported{},
+ },
+ {
+ name: "With resolution capability",
+ nicID: nicID,
+ extraLinkCapabilities: stack.CapabilityResolutionRequired,
+ expectedErr: &tcpip.ErrNotSupported{},
+ },
+ {
+ name: "With IPv6 and resolution capability",
+ nicID: nicID,
+ netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ extraLinkCapabilities: stack.CapabilityResolutionRequired,
+ },
}
-}
-// TestNUDConfigurationFailsForNotSupported tests to make sure we get a
-// NotSupported error if we attempt to retrieve or set NUD configurations when
-// the stack doesn't support NUD.
-//
-// The stack will report to not support NUD if a neighbor cache for a given NIC
-// is not allocated. The networking stack will only allocate neighbor caches if
-// the NIC requires link resolution.
-func TestNUDConfigurationFailsForNotSupported(t *testing.T) {
- const nicID = 1
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ UseNeighborCache: true,
+ NetworkProtocols: test.netProtoFactory,
+ Clock: clock,
+ })
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired
+ e := channel.New(0, 0, linkAddr1)
+ e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired
+ e.LinkEPCapabilities |= test.extraLinkCapabilities
- s := stack.New(stack.Options{
- NUDConfigs: stack.DefaultNUDConfigurations(),
- UseNeighborCache: true,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- t.Run("Get", func(t *testing.T) {
- _, err := s.NUDConfigurations(nicID)
- if _, ok := err.(*tcpip.ErrNotSupported); !ok {
- t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{})
- }
- })
+ configs := stack.DefaultNUDConfigurations()
+ configs.BaseReachableTime = time.Hour
- t.Run("Set", func(t *testing.T) {
- config := stack.NUDConfigurations{}
- err := s.SetNUDConfigurations(nicID, config)
- if _, ok := err.(*tcpip.ErrNotSupported); !ok {
- t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{})
- }
- })
+ {
+ err := s.SetNUDConfigurations(test.nicID, ipv6.ProtocolNumber, configs)
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Errorf("s.SetNUDConfigurations(%d, %d, _) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff)
+ }
+ }
+
+ {
+ gotConfigs, err := s.NUDConfigurations(test.nicID, ipv6.ProtocolNumber)
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Errorf("s.NUDConfigurations(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff)
+ } else if test.expectedErr == nil {
+ if diff := cmp.Diff(configs, gotConfigs); diff != "" {
+ t.Errorf("got configs mismatch (-want +got):\n%s", diff)
+ }
+ }
+ }
+
+ for _, addr := range []tcpip.Address{llAddr1, llAddr2} {
+ {
+ err := s.AddStaticNeighbor(test.nicID, ipv6.ProtocolNumber, addr, linkAddr1)
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Errorf("s.AddStaticNeighbor(%d, %d, %s, %s) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, addr, linkAddr1, diff)
+ }
+ }
+ }
+
+ {
+ wantErr := test.expectedErr
+ for i := 0; i < 2; i++ {
+ {
+ err := s.RemoveNeighbor(test.nicID, ipv6.ProtocolNumber, llAddr1)
+ if diff := cmp.Diff(wantErr, err); diff != "" {
+ t.Errorf("s.RemoveNeighbor(%d, %d, '') error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff)
+ }
+ }
+
+ if test.expectedErr != nil {
+ break
+ }
+
+ // Removing a neighbor that does not exist should give us a bad address
+ // error.
+ wantErr = &tcpip.ErrBadAddress{}
+ }
+ }
+
+ {
+ neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber)
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff)
+ } else if test.expectedErr == nil {
+ if diff := cmp.Diff(
+ []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}},
+ neighbors,
+ ); diff != "" {
+ t.Errorf("neighbors mismatch (-want +got):\n%s", diff)
+ }
+ }
+ }
+
+ {
+ err := s.ClearNeighbors(test.nicID, ipv6.ProtocolNumber)
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Errorf("s.ClearNeigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff)
+ } else if test.expectedErr == nil {
+ if neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber); err != nil {
+ t.Errorf("s.Neighbors(%d, %d): %s", test.nicID, ipv6.ProtocolNumber, err)
+ } else if len(neighbors) != 0 {
+ t.Errorf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
+ }
+ }
+ }
+ })
+ }
}
-// TestDefaultNUDConfigurationIsValid verifies that calling
-// resetInvalidFields() on the result of DefaultNUDConfigurations() does not
-// change anything. DefaultNUDConfigurations() should return a valid
-// NUDConfigurations.
func TestDefaultNUDConfigurations(t *testing.T) {
const nicID = 1
@@ -129,12 +211,12 @@ func TestDefaultNUDConfigurations(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- c, err := s.NUDConfigurations(nicID)
+ c, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
}
if got, want := c, stack.DefaultNUDConfigurations(); got != want {
- t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want)
+ t.Errorf("got stack.NUDConfigurations(%d, %d) = %+v, want = %+v", nicID, ipv6.ProtocolNumber, got, want)
}
}
@@ -184,9 +266,9 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- sc, err := s.NUDConfigurations(nicID)
+ sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
}
if got := sc.BaseReachableTime; got != test.want {
t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want)
@@ -241,9 +323,9 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- sc, err := s.NUDConfigurations(nicID)
+ sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
}
if got := sc.MinRandomFactor; got != test.want {
t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want)
@@ -321,9 +403,9 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- sc, err := s.NUDConfigurations(nicID)
+ sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
}
if got := sc.MaxRandomFactor; got != test.want {
t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want)
@@ -383,9 +465,9 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- sc, err := s.NUDConfigurations(nicID)
+ sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
}
if got := sc.RetransmitTimer; got != test.want {
t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want)
@@ -435,9 +517,9 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- sc, err := s.NUDConfigurations(nicID)
+ sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
}
if got := sc.DelayFirstProbeTime; got != test.want {
t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want)
@@ -487,9 +569,9 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- sc, err := s.NUDConfigurations(nicID)
+ sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
}
if got := sc.MaxMulticastProbes; got != test.want {
t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want)
@@ -539,9 +621,9 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- sc, err := s.NUDConfigurations(nicID)
+ sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
}
if got := sc.MaxUnicastProbes; got != test.want {
t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want)
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index c652c2bd7..e02f7190c 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -536,11 +536,11 @@ type NetworkInterface interface {
//
// HandleNeighborProbe assumes that the probe is valid for the network
// interface the probe was received on.
- HandleNeighborProbe(tcpip.Address, tcpip.LinkAddress, LinkAddressResolver)
+ HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error
// HandleNeighborConfirmation processes an incoming neighbor confirmation
// (e.g. ARP reply or NDP Neighbor Advertisement).
- HandleNeighborConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags)
+ HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) tcpip.Error
}
// LinkResolvableNetworkEndpoint handles link resolution events.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 1c8ef6ed4..bab55ce49 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -53,7 +53,7 @@ type Route struct {
// linkRes is set if link address resolution is enabled for this protocol on
// the route's NIC.
- linkRes LinkAddressResolver
+ linkRes linkResolver
}
type routeInfo struct {
@@ -184,11 +184,11 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA
return r
}
- if r.linkRes == nil {
+ if r.linkRes.resolver == nil {
return r
}
- if linkAddr, ok := r.linkRes.ResolveStaticAddress(r.RemoteAddress); ok {
+ if linkAddr, ok := r.linkRes.resolver.ResolveStaticAddress(r.RemoteAddress); ok {
r.ResolveWith(linkAddr)
return r
}
@@ -362,7 +362,7 @@ func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteIn
}
afterResolveFields := fields
- linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) {
+ linkAddr, ch, err := r.linkRes.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, func(r LinkResolutionResult) {
if afterResolve != nil {
if r.Success {
afterResolveFields.RemoteLinkAddress = r.LinkAddress
@@ -400,7 +400,7 @@ func (r *Route) IsResolutionRequired() bool {
}
func (r *Route) isResolutionRequiredRLocked() bool {
- return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local()
+ return len(r.mu.remoteLinkAddress) == 0 && r.linkRes.resolver != nil && r.isValidForOutgoingRLocked() && !r.local()
}
func (r *Route) isValidForOutgoing() bool {
@@ -528,5 +528,7 @@ func (r *Route) IsOutboundBroadcast() bool {
// "Reachable" is defined as having full-duplex communication between the
// local and remote ends of the route.
func (r *Route) ConfirmReachable() {
- r.outgoingNIC.confirmReachable(r.nextHop())
+ if r.linkRes.resolver != nil {
+ r.linkRes.confirmReachable(r.nextHop())
+ }
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 73db6e031..9390aaf57 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1560,7 +1560,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
}
// Neighbors returns all IP to MAC address associations.
-func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) {
+func (s *Stack) Neighbors(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
@@ -1569,11 +1569,11 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) {
return nil, &tcpip.ErrUnknownNICID{}
}
- return nic.neighbors()
+ return nic.neighbors(protocol)
}
// AddStaticNeighbor statically associates an IP address to a MAC address.
-func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error {
+func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
@@ -1582,13 +1582,13 @@ func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAdd
return &tcpip.ErrUnknownNICID{}
}
- return nic.addStaticNeighbor(addr, linkAddr)
+ return nic.addStaticNeighbor(addr, protocol, linkAddr)
}
// RemoveNeighbor removes an IP to MAC address association previously created
// either automically or by AddStaticNeighbor. Returns ErrBadAddress if there
// is no association with the provided address.
-func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Error {
+func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
@@ -1597,11 +1597,11 @@ func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Erro
return &tcpip.ErrUnknownNICID{}
}
- return nic.removeNeighbor(addr)
+ return nic.removeNeighbor(protocol, addr)
}
// ClearNeighbors removes all IP to MAC address associations.
-func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error {
+func (s *Stack) ClearNeighbors(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
@@ -1610,7 +1610,7 @@ func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error {
return &tcpip.ErrUnknownNICID{}
}
- return nic.clearNeighbors()
+ return nic.clearNeighbors(protocol)
}
// RegisterTransportEndpoint registers the given endpoint with the stack
@@ -1980,7 +1980,7 @@ func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtoco
}
// NUDConfigurations gets the per-interface NUD configurations.
-func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Error) {
+func (s *Stack) NUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) {
s.mu.RLock()
nic, ok := s.nics[id]
s.mu.RUnlock()
@@ -1989,14 +1989,14 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Erro
return NUDConfigurations{}, &tcpip.ErrUnknownNICID{}
}
- return nic.nudConfigs()
+ return nic.nudConfigs(proto)
}
// SetNUDConfigurations sets the per-interface NUD configurations.
//
// Note, if c contains invalid NUD configuration values, it will be fixed to
// use default values for the erroneous values.
-func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip.Error {
+func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[id]
s.mu.RUnlock()
@@ -2005,7 +2005,7 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip.
return &tcpip.ErrUnknownNICID{}
}
- return nic.setNUDConfigs(c)
+ return nic.setNUDConfigs(proto, c)
}
// Seed returns a 32 bit value that can be used as a seed value for port
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index a166c0502..375cd3080 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -4313,9 +4314,11 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) {
linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
UseNeighborCache: true,
+ Clock: clock,
})
e := channel.New(0, 0, "")
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
@@ -4323,36 +4326,56 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddStaticNeighbor(nicID, ipv4Addr, linkAddr); err != nil {
- t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv4Addr, linkAddr, err)
- }
- if err := s.AddStaticNeighbor(nicID, ipv6Addr, linkAddr); err != nil {
- t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv6Addr, linkAddr, err)
+ addrs := []struct {
+ proto tcpip.NetworkProtocolNumber
+ addr tcpip.Address
+ }{
+ {
+ proto: ipv4.ProtocolNumber,
+ addr: ipv4Addr,
+ },
+ {
+ proto: ipv6.ProtocolNumber,
+ addr: ipv6Addr,
+ },
}
- if neighbors, err := s.Neighbors(nicID); err != nil {
- t.Fatalf("s.Neighbors(%d): %s", nicID, err)
- } else if len(neighbors) != 2 {
- t.Fatalf("got len(neighbors) = %d, want = 2; neighbors = %#v", len(neighbors), neighbors)
+ for _, addr := range addrs {
+ if err := s.AddStaticNeighbor(nicID, addr.proto, addr.addr, linkAddr); err != nil {
+ t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, addr.proto, addr.addr, linkAddr, err)
+ }
+
+ if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil {
+ t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err)
+ } else if diff := cmp.Diff(
+ []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}},
+ neighbors,
+ ); diff != "" {
+ t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff)
+ }
}
// Disabling the NIC should clear the neighbor table.
if err := s.DisableNIC(nicID); err != nil {
t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
}
- if neighbors, err := s.Neighbors(nicID); err != nil {
- t.Fatalf("s.Neighbors(%d): %s", nicID, err)
- } else if len(neighbors) != 0 {
- t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
+ for _, addr := range addrs {
+ if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil {
+ t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err)
+ } else if len(neighbors) != 0 {
+ t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors)
+ }
}
// Enabling the NIC should have an empty neighbor table.
if err := s.EnableNIC(nicID); err != nil {
t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
}
- if neighbors, err := s.Neighbors(nicID); err != nil {
- t.Fatalf("s.Neighbors(%d): %s", nicID, err)
- } else if len(neighbors) != 0 {
- t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
+ for _, addr := range addrs {
+ if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil {
+ t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err)
+ } else if len(neighbors) != 0 {
+ t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors)
+ }
}
}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 7069352f2..b3a5d49d7 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -1069,9 +1069,9 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
// Wait for the remote's neighbor entry to be stale before creating a
// TCP connection from host1 to some remote.
- nudConfigs, err := host1Stack.NUDConfigurations(host1NICID)
+ nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
if err != nil {
- t.Fatalf("host1Stack.NUDConfigurations(%d): %s", host1NICID, err)
+ t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
}
// The maximum reachable time for a neighbor is some maximum random factor
// applied to the base reachable time.