summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-01-19 16:54:48 -0800
committergVisor bot <gvisor-bot@google.com>2021-01-19 16:56:49 -0800
commit7ff5ceaeae66303ed6a2199963c00cb08b2fe7ca (patch)
tree5b2ab07f55394c2106a4cbd71046cc069b681f23
parent48ea2c34d1d3dead7727d9e2760b587c7609b14b (diff)
Do not have a stack-wide linkAddressCache
Link addresses are cached on a per NIC basis so instead of having a single cache that includes the NIC ID for neighbor entry lookups, use a single cache per NIC. PiperOrigin-RevId: 352684111
-rw-r--r--pkg/tcpip/network/arp/arp.go4
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go4
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go2
-rw-r--r--pkg/tcpip/stack/forwarding_test.go24
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go20
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go64
-rw-r--r--pkg/tcpip/stack/ndp_test.go10
-rw-r--r--pkg/tcpip/stack/nic.go7
-rw-r--r--pkg/tcpip/stack/registration.go2
-rw-r--r--pkg/tcpip/stack/stack.go21
10 files changed, 85 insertions, 73 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 1d4d2966e..9255a4f6a 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -147,7 +147,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), remoteAddr, remoteLinkAddr)
+ e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr)
} else {
e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
@@ -191,7 +191,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(addr, linkAddr)
return
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 47e8aa11a..ae5179d93 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -290,7 +290,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
} else if e.nud != nil {
e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr)
+ e.linkAddrCache.AddLinkAddress(srcAddr, sourceLinkAddr)
}
// As per RFC 4861 section 7.1.1:
@@ -445,7 +445,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// address cache with the link address for the target of the message.
if e.nud == nil {
if len(targetLinkAddr) != 0 {
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr)
+ e.linkAddrCache.AddLinkAddress(targetAddr, targetLinkAddr)
}
return
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index a660a1cea..defea46b0 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -95,7 +95,7 @@ var _ stack.LinkAddressCache = (*stubLinkAddressCache)(nil)
type stubLinkAddressCache struct{}
-func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {}
+func (*stubLinkAddressCache) AddLinkAddress(tcpip.Address, tcpip.LinkAddress) {}
type stubNUDHandler struct {
probeCount int
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 9f2fd8181..d29c9a49b 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -368,10 +368,6 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC
UseNeighborCache: useNeighborCache,
})
- if !useNeighborCache {
- proto.addrCache = s.linkAddrCache
- }
-
// Enable forwarding.
s.SetForwarding(proto.Number(), true)
@@ -401,13 +397,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC
t.Fatal("AddAddress #2 failed:", err)
}
+ nic, ok := s.nics[2]
+ if !ok {
+ t.Fatal("NIC 2 does not exist")
+ }
if useNeighborCache {
// Control the neighbor cache for NIC 2.
- nic, ok := s.nics[2]
- if !ok {
- t.Fatal("failed to get the neighbor cache for NIC 2")
- }
proto.neigh = nic.neigh
+ } else {
+ proto.addrCache = nic.linkAddrCache
}
// Route all packets to NIC 2.
@@ -493,7 +491,7 @@ func TestForwardingWithFakeResolver(t *testing.T) {
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any address will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
},
},
},
@@ -619,7 +617,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Only packets to address 3 will be resolved to the
// link address "c".
if addr == "\x03" {
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
}
},
},
@@ -704,7 +702,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
},
},
},
@@ -780,7 +778,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
},
},
},
@@ -870,7 +868,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
},
},
},
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index b600a1cab..3c4fa341e 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -24,6 +24,8 @@ import (
const linkAddrCacheSize = 512 // max cache entries
+var _ LinkAddressCache = (*linkAddrCache)(nil)
+
// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses.
//
// The entries are stored in a ring buffer, oldest entry replaced first.
@@ -43,7 +45,7 @@ type linkAddrCache struct {
cache struct {
sync.Mutex
- table map[tcpip.FullAddress]*linkAddrEntry
+ table map[tcpip.Address]*linkAddrEntry
lru linkAddrEntryList
}
}
@@ -81,7 +83,7 @@ type linkAddrEntry struct {
// mu protects the fields below.
mu sync.RWMutex
- addr tcpip.FullAddress
+ addr tcpip.Address
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
@@ -125,7 +127,7 @@ func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) {
}
// add adds a k -> v mapping to the cache.
-func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
+func (c *linkAddrCache) AddLinkAddress(k tcpip.Address, v tcpip.LinkAddress) {
// Calculate expiration time before acquiring the lock, since expiration is
// relative to the time when information was learned, rather than when it
// happened to be inserted into the cache.
@@ -150,7 +152,7 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
// reset to state incomplete, and returned. If no matching entry exists and the
// cache is not full, a new entry with state incomplete is allocated and
// returned.
-func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry {
+func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry {
if entry, ok := c.cache.table[k]; ok {
c.cache.lru.Remove(entry)
c.cache.lru.PushFront(entry)
@@ -181,7 +183,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
c.cache.Lock()
defer c.cache.Unlock()
entry := c.getOrCreateEntryLocked(k)
@@ -214,11 +216,11 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
}
}
-func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
+func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic)
+ linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */, nic)
select {
case now := <-time.After(c.resolutionTimeout):
@@ -234,7 +236,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
// checkLinkRequest checks whether previous attempt to resolve address has
// succeeded and mark the entry accordingly. Returns true if request can stop,
// false if another request should be sent.
-func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
+func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt int) bool {
c.cache.Lock()
defer c.cache.Unlock()
entry, ok := c.cache.table[k]
@@ -268,6 +270,6 @@ func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttem
resolutionTimeout: resolutionTimeout,
resolutionAttempts: resolutionAttempts,
}
- c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize)
+ c.cache.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize)
return c
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index d7ac6cf5f..8c35067c6 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -26,7 +26,7 @@ import (
)
type testaddr struct {
- addr tcpip.FullAddress
+ addr tcpip.Address
linkAddr tcpip.LinkAddress
}
@@ -35,7 +35,7 @@ var testAddrs = func() []testaddr {
for i := 0; i < 4*linkAddrCacheSize; i++ {
addr := fmt.Sprintf("Addr%06d", i)
addrs = append(addrs, testaddr{
- addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
+ addr: tcpip.Address(addr),
linkAddr: tcpip.LinkAddress("Link" + addr),
})
}
@@ -59,8 +59,8 @@ func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address
func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
for _, ta := range testAddrs {
- if ta.addr.Addr == addr {
- r.cache.add(ta.addr, ta.linkAddr)
+ if ta.addr == addr {
+ r.cache.AddLinkAddress(ta.addr, ta.linkAddr)
break
}
}
@@ -77,7 +77,7 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe
return 1
}
-func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
+func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
var attemptedResolution bool
for {
got, ch, err := c.get(addr, linkRes, "", nil, nil)
@@ -97,13 +97,13 @@ func TestCacheOverflow(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
for i := len(testAddrs) - 1; i >= 0; i-- {
e := testAddrs[i]
- c.add(e.addr, e.linkAddr)
+ c.AddLinkAddress(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
- t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+ t.Errorf("insert %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("insert %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
}
}
// Expect to find at least half of the most recent entries.
@@ -111,10 +111,10 @@ func TestCacheOverflow(t *testing.T) {
e := testAddrs[i]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
- t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+ t.Errorf("check %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("check %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
}
}
// The earliest entries should no longer be in the cache.
@@ -123,7 +123,7 @@ func TestCacheOverflow(t *testing.T) {
for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
e := testAddrs[i]
if entry, ok := c.cache.table[e.addr]; ok {
- t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
+ t.Errorf("unexpected entry at c.cache.table[%s]: %#v", e.addr, entry)
}
}
}
@@ -137,7 +137,7 @@ func TestCacheConcurrent(t *testing.T) {
wg.Add(1)
go func() {
for _, e := range testAddrs {
- c.add(e.addr, e.linkAddr)
+ c.AddLinkAddress(e.addr, e.linkAddr)
}
wg.Done()
}()
@@ -150,17 +150,17 @@ func TestCacheConcurrent(t *testing.T) {
e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
e = testAddrs[0]
c.cache.Lock()
defer c.cache.Unlock()
if entry, ok := c.cache.table[e.addr]; ok {
- t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
+ t.Errorf("unexpected entry at c.cache.table[%s]: %#v", e.addr, entry)
}
}
@@ -169,10 +169,10 @@ func TestCacheAgeLimit(t *testing.T) {
linkRes := &testLinkAddressResolver{cache: c}
e := testAddrs[0]
- c.add(e.addr, e.linkAddr)
+ c.AddLinkAddress(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err)
+ t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = ErrWouldBlock", e.addr, err)
}
}
@@ -180,22 +180,22 @@ func TestCacheReplace(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
e := testAddrs[0]
l2 := e.linkAddr + "2"
- c.add(e.addr, e.linkAddr)
+ c.AddLinkAddress(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
- c.add(e.addr, l2)
+ c.AddLinkAddress(e.addr, l2)
got, _, err = c.get(e.addr, nil, "", nil, nil)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err)
}
if got != l2 {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2)
+ t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, l2)
}
}
@@ -211,10 +211,10 @@ func TestCacheResolution(t *testing.T) {
for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
if err != nil {
- t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
+ t.Errorf("check %d, getBlocking(_, %s, _): %s", i, ta.addr, err)
}
if got != ta.linkAddr {
- t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr)
+ t.Errorf("check %d, got getBlocking(_, %s, _) = %s, want = %s", i, ta.addr, got, ta.linkAddr)
}
}
@@ -223,10 +223,10 @@ func TestCacheResolution(t *testing.T) {
e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
}
}
@@ -244,17 +244,17 @@ func TestCacheResolutionFailed(t *testing.T) {
e := testAddrs[0]
got, err := getBlocking(c, e.addr, linkRes)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("getBlocking(_, %s, _): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("got getBlocking(_, %s, _) = %s, want = %s", e.addr, got, e.linkAddr)
}
before := atomic.LoadUint32(&requestCount)
- e.addr.Addr += "2"
+ e.addr += "2"
if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout {
- t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout)
+ t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout)
}
if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
@@ -270,6 +270,6 @@ func TestCacheResolutionTimeout(t *testing.T) {
e := testAddrs[0]
if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout {
- t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout)
+ t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout)
}
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 61636cae5..270f5fb1a 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -2808,6 +2808,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -2827,10 +2828,15 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN
Gateway: llAddr3,
NIC: nicID,
}})
+
if useNeighborCache {
- s.AddStaticNeighbor(nicID, llAddr3, linkAddr3)
+ if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil {
+ t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err)
+ }
} else {
- s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+ if err := s.AddLinkAddress(nicID, llAddr3, linkAddr3); err != nil {
+ t.Fatalf("s.AddLinkAddress(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err)
+ }
}
return ndpDisp, e, s
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 0f545f255..f2bca93d3 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -53,6 +53,8 @@ type NIC struct {
// complete.
linkResQueue packetsPendingLinkResolution
+ linkAddrCache *linkAddrCache
+
mu struct {
sync.RWMutex
spoofing bool
@@ -137,6 +139,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
context: ctx,
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
}
nic.linkResQueue.init()
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
@@ -167,7 +170,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
nic.mu.packetEPs[netNum] = new(packetEndpointList)
- nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
+ nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic)
}
nic.LinkEndpoint.Attach(nic)
@@ -558,7 +561,7 @@ func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes Link
return entry.LinkAddr, ch, err
}
- return n.stack.linkAddrCache.get(tcpip.FullAddress{NIC: n.ID(), Addr: addr}, linkRes, localAddr, n, onResolve)
+ return n.linkAddrCache.get(addr, linkRes, localAddr, n, onResolve)
}
func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) {
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 34c122728..33df192aa 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -850,7 +850,7 @@ type LinkAddressResolver interface {
// A LinkAddressCache caches link addresses.
type LinkAddressCache interface {
// AddLinkAddress adds a link address to the cache.
- AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
+ AddLinkAddress(addr tcpip.Address, linkAddr tcpip.LinkAddress)
}
// RawFactory produces endpoints for writing various types of raw packets.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index b4878669c..4685fa4cf 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -382,8 +382,6 @@ type Stack struct {
stats tcpip.Stats
- linkAddrCache *linkAddrCache
-
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
@@ -636,7 +634,6 @@ func New(opts Options) *Stack {
linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
nics: make(map[tcpip.NICID]*NIC),
cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
PortManager: ports.NewPortManager(),
clock: clock,
stats: opts.Stats.FillIn(),
@@ -1516,12 +1513,18 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
return nil
}
-// AddLinkAddress adds a link address to the stack link cache.
-func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
- s.linkAddrCache.add(fullAddr, linkAddr)
- // TODO: provide a way for a transport endpoint to receive a signal
- // that AddLinkAddress for a particular address has been called.
+// AddLinkAddress adds a link address for the neighbor on the specified NIC.
+func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.linkAddrCache.AddLinkAddress(neighbor, linkAddr)
+ return nil
}
// GetLinkAddress finds the link address corresponding to a neighbor's address.