diff options
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 64 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 21 |
10 files changed, 85 insertions, 73 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 1d4d2966e..9255a4f6a 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -147,7 +147,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) if e.nud == nil { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), remoteAddr, remoteLinkAddr) + e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr) } else { e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } @@ -191,7 +191,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) if e.nud == nil { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) + e.linkAddrCache.AddLinkAddress(addr, linkAddr) return } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 47e8aa11a..ae5179d93 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -290,7 +290,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } else if e.nud != nil { e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr) + e.linkAddrCache.AddLinkAddress(srcAddr, sourceLinkAddr) } // As per RFC 4861 section 7.1.1: @@ -445,7 +445,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // address cache with the link address for the target of the message. if e.nud == nil { if len(targetLinkAddr) != 0 { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr) + e.linkAddrCache.AddLinkAddress(targetAddr, targetLinkAddr) } return } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index a660a1cea..defea46b0 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -95,7 +95,7 @@ var _ stack.LinkAddressCache = (*stubLinkAddressCache)(nil) type stubLinkAddressCache struct{} -func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {} +func (*stubLinkAddressCache) AddLinkAddress(tcpip.Address, tcpip.LinkAddress) {} type stubNUDHandler struct { probeCount int diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 9f2fd8181..d29c9a49b 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -368,10 +368,6 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC UseNeighborCache: useNeighborCache, }) - if !useNeighborCache { - proto.addrCache = s.linkAddrCache - } - // Enable forwarding. s.SetForwarding(proto.Number(), true) @@ -401,13 +397,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC t.Fatal("AddAddress #2 failed:", err) } + nic, ok := s.nics[2] + if !ok { + t.Fatal("NIC 2 does not exist") + } if useNeighborCache { // Control the neighbor cache for NIC 2. - nic, ok := s.nics[2] - if !ok { - t.Fatal("failed to get the neighbor cache for NIC 2") - } proto.neigh = nic.neigh + } else { + proto.addrCache = nic.linkAddrCache } // Route all packets to NIC 2. @@ -493,7 +491,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any address will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, @@ -619,7 +617,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Only packets to address 3 will be resolved to the // link address "c". if addr == "\x03" { - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") } }, }, @@ -704,7 +702,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, @@ -780,7 +778,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, @@ -870,7 +868,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index b600a1cab..3c4fa341e 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -24,6 +24,8 @@ import ( const linkAddrCacheSize = 512 // max cache entries +var _ LinkAddressCache = (*linkAddrCache)(nil) + // linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses. // // The entries are stored in a ring buffer, oldest entry replaced first. @@ -43,7 +45,7 @@ type linkAddrCache struct { cache struct { sync.Mutex - table map[tcpip.FullAddress]*linkAddrEntry + table map[tcpip.Address]*linkAddrEntry lru linkAddrEntryList } } @@ -81,7 +83,7 @@ type linkAddrEntry struct { // mu protects the fields below. mu sync.RWMutex - addr tcpip.FullAddress + addr tcpip.Address linkAddr tcpip.LinkAddress expiration time.Time s entryState @@ -125,7 +127,7 @@ func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { } // add adds a k -> v mapping to the cache. -func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { +func (c *linkAddrCache) AddLinkAddress(k tcpip.Address, v tcpip.LinkAddress) { // Calculate expiration time before acquiring the lock, since expiration is // relative to the time when information was learned, rather than when it // happened to be inserted into the cache. @@ -150,7 +152,7 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. -func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry { +func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { if entry, ok := c.cache.table[k]; ok { c.cache.lru.Remove(entry) c.cache.lru.PushFront(entry) @@ -181,7 +183,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { c.cache.Lock() defer c.cache.Unlock() entry := c.getOrCreateEntryLocked(k) @@ -214,11 +216,11 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo } } -func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic) + linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */, nic) select { case now := <-time.After(c.resolutionTimeout): @@ -234,7 +236,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link // checkLinkRequest checks whether previous attempt to resolve address has // succeeded and mark the entry accordingly. Returns true if request can stop, // false if another request should be sent. -func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { +func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt int) bool { c.cache.Lock() defer c.cache.Unlock() entry, ok := c.cache.table[k] @@ -268,6 +270,6 @@ func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttem resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, } - c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize) + c.cache.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) return c } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index d7ac6cf5f..8c35067c6 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -26,7 +26,7 @@ import ( ) type testaddr struct { - addr tcpip.FullAddress + addr tcpip.Address linkAddr tcpip.LinkAddress } @@ -35,7 +35,7 @@ var testAddrs = func() []testaddr { for i := 0; i < 4*linkAddrCacheSize; i++ { addr := fmt.Sprintf("Addr%06d", i) addrs = append(addrs, testaddr{ - addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, + addr: tcpip.Address(addr), linkAddr: tcpip.LinkAddress("Link" + addr), }) } @@ -59,8 +59,8 @@ func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { for _, ta := range testAddrs { - if ta.addr.Addr == addr { - r.cache.add(ta.addr, ta.linkAddr) + if ta.addr == addr { + r.cache.AddLinkAddress(ta.addr, ta.linkAddr) break } } @@ -77,7 +77,7 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe return 1 } -func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { +func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { var attemptedResolution bool for { got, ch, err := c.get(addr, linkRes, "", nil, nil) @@ -97,13 +97,13 @@ func TestCacheOverflow(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) for i := len(testAddrs) - 1; i >= 0; i-- { e := testAddrs[i] - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) + t.Errorf("insert %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) + t.Errorf("insert %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // Expect to find at least half of the most recent entries. @@ -111,10 +111,10 @@ func TestCacheOverflow(t *testing.T) { e := testAddrs[i] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) + t.Errorf("check %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) + t.Errorf("check %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // The earliest entries should no longer be in the cache. @@ -123,7 +123,7 @@ func TestCacheOverflow(t *testing.T) { for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { e := testAddrs[i] if entry, ok := c.cache.table[e.addr]; ok { - t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) + t.Errorf("unexpected entry at c.cache.table[%s]: %#v", e.addr, entry) } } } @@ -137,7 +137,7 @@ func TestCacheConcurrent(t *testing.T) { wg.Add(1) go func() { for _, e := range testAddrs { - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) } wg.Done() }() @@ -150,17 +150,17 @@ func TestCacheConcurrent(t *testing.T) { e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) } e = testAddrs[0] c.cache.Lock() defer c.cache.Unlock() if entry, ok := c.cache.table[e.addr]; ok { - t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) + t.Errorf("unexpected entry at c.cache.table[%s]: %#v", e.addr, entry) } } @@ -169,10 +169,10 @@ func TestCacheAgeLimit(t *testing.T) { linkRes := &testLinkAddressResolver{cache: c} e := testAddrs[0] - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err) + t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = ErrWouldBlock", e.addr, err) } } @@ -180,22 +180,22 @@ func TestCacheReplace(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) e := testAddrs[0] l2 := e.linkAddr + "2" - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) } - c.add(e.addr, l2) + c.AddLinkAddress(e.addr, l2) got, _, err = c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) } if got != l2 { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2) + t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, l2) } } @@ -211,10 +211,10 @@ func TestCacheResolution(t *testing.T) { for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) + t.Errorf("check %d, getBlocking(_, %s, _): %s", i, ta.addr, err) } if got != ta.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr) + t.Errorf("check %d, got getBlocking(_, %s, _) = %s, want = %s", i, ta.addr, got, ta.linkAddr) } } @@ -223,10 +223,10 @@ func TestCacheResolution(t *testing.T) { e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) } } } @@ -244,17 +244,17 @@ func TestCacheResolutionFailed(t *testing.T) { e := testAddrs[0] got, err := getBlocking(c, e.addr, linkRes) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("getBlocking(_, %s, _): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got getBlocking(_, %s, _) = %s, want = %s", e.addr, got, e.linkAddr) } before := atomic.LoadUint32(&requestCount) - e.addr.Addr += "2" + e.addr += "2" if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { - t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) + t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) } if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { @@ -270,6 +270,6 @@ func TestCacheResolutionTimeout(t *testing.T) { e := testAddrs[0] if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { - t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) + t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 61636cae5..270f5fb1a 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -2808,6 +2808,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -2827,10 +2828,15 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN Gateway: llAddr3, NIC: nicID, }}) + if useNeighborCache { - s.AddStaticNeighbor(nicID, llAddr3, linkAddr3) + if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) + } } else { - s.AddLinkAddress(nicID, llAddr3, linkAddr3) + if err := s.AddLinkAddress(nicID, llAddr3, linkAddr3); err != nil { + t.Fatalf("s.AddLinkAddress(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) + } } return ndpDisp, e, s } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0f545f255..f2bca93d3 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -53,6 +53,8 @@ type NIC struct { // complete. linkResQueue packetsPendingLinkResolution + linkAddrCache *linkAddrCache + mu struct { sync.RWMutex spoofing bool @@ -137,6 +139,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC context: ctx, stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), + linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), } nic.linkResQueue.init() nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) @@ -167,7 +170,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC for _, netProto := range stack.networkProtocols { netNum := netProto.Number() nic.mu.packetEPs[netNum] = new(packetEndpointList) - nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) + nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic) } nic.LinkEndpoint.Attach(nic) @@ -558,7 +561,7 @@ func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes Link return entry.LinkAddr, ch, err } - return n.stack.linkAddrCache.get(tcpip.FullAddress{NIC: n.ID(), Addr: addr}, linkRes, localAddr, n, onResolve) + return n.linkAddrCache.get(addr, linkRes, localAddr, n, onResolve) } func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 34c122728..33df192aa 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -850,7 +850,7 @@ type LinkAddressResolver interface { // A LinkAddressCache caches link addresses. type LinkAddressCache interface { // AddLinkAddress adds a link address to the cache. - AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) + AddLinkAddress(addr tcpip.Address, linkAddr tcpip.LinkAddress) } // RawFactory produces endpoints for writing various types of raw packets. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index b4878669c..4685fa4cf 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -382,8 +382,6 @@ type Stack struct { stats tcpip.Stats - linkAddrCache *linkAddrCache - mu sync.RWMutex nics map[tcpip.NICID]*NIC @@ -636,7 +634,6 @@ func New(opts Options) *Stack { linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), cleanupEndpoints: make(map[TransportEndpoint]struct{}), - linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), PortManager: ports.NewPortManager(), clock: clock, stats: opts.Stats.FillIn(), @@ -1516,12 +1513,18 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { return nil } -// AddLinkAddress adds a link address to the stack link cache. -func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} - s.linkAddrCache.add(fullAddr, linkAddr) - // TODO: provide a way for a transport endpoint to receive a signal - // that AddLinkAddress for a particular address has been called. +// AddLinkAddress adds a link address for the neighbor on the specified NIC. +func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[nicID] + if !ok { + return tcpip.ErrUnknownNICID + } + + nic.linkAddrCache.AddLinkAddress(neighbor, linkAddr) + return nil } // GetLinkAddress finds the link address corresponding to a neighbor's address. |