diff options
-rw-r--r-- | pkg/tcpip/header/ipv6.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/BUILD | 14 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 253 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 79 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 19 |
7 files changed, 200 insertions, 189 deletions
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 31be42ce0..bc4e56535 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -27,7 +27,7 @@ const ( nextHdr = 6 hopLimit = 7 v6SrcAddr = 8 - v6DstAddr = 24 + v6DstAddr = v6SrcAddr + IPv6AddressSize ) // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the @@ -119,13 +119,13 @@ func (b IPv6) Payload() []byte { // SourceAddress returns the "source address" field of the ipv6 header. func (b IPv6) SourceAddress() tcpip.Address { - return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize]) + return tcpip.Address(b[v6SrcAddr:][:IPv6AddressSize]) } // DestinationAddress returns the "destination address" field of the ipv6 // header. func (b IPv6) DestinationAddress() tcpip.Address { - return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize]) + return tcpip.Address(b[v6DstAddr:][:IPv6AddressSize]) } // Checksum implements Network.Checksum. Given that IPv6 doesn't have a @@ -153,13 +153,13 @@ func (b IPv6) SetPayloadLength(payloadLength uint16) { // SetSourceAddress sets the "source address" field of the ipv6 header. func (b IPv6) SetSourceAddress(addr tcpip.Address) { - copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr) + copy(b[v6SrcAddr:][:IPv6AddressSize], addr) } // SetDestinationAddress sets the "destination address" field of the ipv6 // header. func (b IPv6) SetDestinationAddress(addr tcpip.Address) { - copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr) + copy(b[v6DstAddr:][:IPv6AddressSize], addr) } // SetNextHeader sets the value of the "next header" field of the ipv6 header. @@ -178,8 +178,8 @@ func (b IPv6) Encode(i *IPv6Fields) { b.SetPayloadLength(i.PayloadLength) b[nextHdr] = i.NextHeader b[hopLimit] = i.HopLimit - copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr) - copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr) + b.SetSourceAddress(i.SrcAddr) + b.SetDestinationAddress(i.DstAddr) } // IsValid performs basic validation on the packet. diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index ea7296e6a..fd6395fc1 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -112,11 +112,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender()) copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) - fallthrough // also fill the cache from requests case header.ARPReply: - addr := tcpip.Address(h.ProtocolAddressSender()) - linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr) } } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 5e6a59e91..1689af16f 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -100,13 +100,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) - if len(v) < header.ICMPv6NeighborSolicitMinimumSize { received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:16]) + targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 { // We don't have a useful answer; the best we can do is ignore the request. return @@ -146,7 +144,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:16]) + targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) if targetAddr != r.RemoteAddress { e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 9986b4be3..ebbcea601 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,11 +1,25 @@ package(licenses = ["notice"]) +load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +go_template_instance( + name = "linkaddrentry_list", + out = "linkaddrentry_list.go", + package = "stack", + prefix = "linkAddrEntry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*linkAddrEntry", + "Linker": "*linkAddrEntry", + }, +) + go_library( name = "stack", srcs = [ "linkaddrcache.go", + "linkaddrentry_list.go", "nic.go", "registration.go", "route.go", diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 77bb0ccb9..267df60d1 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -42,10 +42,11 @@ type linkAddrCache struct { // resolved before failing. resolutionAttempts int - mu sync.Mutex - cache map[tcpip.FullAddress]*linkAddrEntry - next int // array index of next available entry - entries [linkAddrCacheSize]linkAddrEntry + cache struct { + sync.Mutex + table map[tcpip.FullAddress]*linkAddrEntry + lru linkAddrEntryList + } } // entryState controls the state of a single entry in the cache. @@ -60,9 +61,6 @@ const ( // failed means that address resolution timed out and the address // could not be resolved. failed - // expired means that the cache entry has expired and the address must be - // resolved again. - expired ) // String implements Stringer. @@ -74,8 +72,6 @@ func (s entryState) String() string { return "ready" case failed: return "failed" - case expired: - return "expired" default: return fmt.Sprintf("unknown(%d)", s) } @@ -84,64 +80,46 @@ func (s entryState) String() string { // A linkAddrEntry is an entry in the linkAddrCache. // This struct is thread-compatible. type linkAddrEntry struct { + linkAddrEntryEntry + addr tcpip.FullAddress linkAddr tcpip.LinkAddress expiration time.Time s entryState // wakers is a set of waiters for address resolution result. Anytime - // state transitions out of 'incomplete' these waiters are notified. + // state transitions out of incomplete these waiters are notified. wakers map[*sleep.Waker]struct{} + // done is used to allow callers to wait on address resolution. It is nil iff + // s is incomplete and resolution is not yet in progress. done chan struct{} } -func (e *linkAddrEntry) state() entryState { - if e.s != expired && time.Now().After(e.expiration) { - // Force the transition to ensure waiters are notified. - e.changeState(expired) - } - return e.s -} - -func (e *linkAddrEntry) changeState(ns entryState) { - if e.s == ns { - return - } - - // Validate state transition. - switch e.s { - case incomplete: - // All transitions are valid. - case ready, failed: - if ns != expired { - panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) - } - case expired: - // Terminal state. - panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) - default: - panic(fmt.Sprintf("invalid state: %s", e.s)) - } - +// changeState sets the entry's state to ns, notifying any waiters. +// +// The entry's expiration is bumped up to the greater of itself and the passed +// expiration; the zero value indicates immediate expiration, and is set +// unconditionally - this is an implementation detail that allows for entries +// to be reused. +func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { // Notify whoever is waiting on address resolution when transitioning - // out of 'incomplete'. - if e.s == incomplete { + // out of incomplete. + if e.s == incomplete && ns != incomplete { for w := range e.wakers { w.Assert() } e.wakers = nil - if e.done != nil { - close(e.done) + if ch := e.done; ch != nil { + close(ch) } + e.done = nil } - e.s = ns -} -func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) { - if w != nil { - e.wakers[w] = struct{}{} + if expiration.IsZero() || expiration.After(e.expiration) { + e.expiration = expiration } + e.s = ns } func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { @@ -150,53 +128,54 @@ func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { // add adds a k -> v mapping to the cache. func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { - c.mu.Lock() - defer c.mu.Unlock() - - entry, ok := c.cache[k] - if ok { - s := entry.state() - if s != expired && entry.linkAddr == v { - // Disregard repeated calls. - return - } - // Check if entry is waiting for address resolution. - if s == incomplete { - entry.linkAddr = v - } else { - // Otherwise create a new entry to replace it. - entry = c.makeAndAddEntry(k, v) - } - } else { - entry = c.makeAndAddEntry(k, v) - } + // Calculate expiration time before acquiring the lock, since expiration is + // relative to the time when information was learned, rather than when it + // happened to be inserted into the cache. + expiration := time.Now().Add(c.ageLimit) - entry.changeState(ready) + c.cache.Lock() + entry := c.getOrCreateEntryLocked(k) + entry.linkAddr = v + + entry.changeState(ready, expiration) + c.cache.Unlock() } -// makeAndAddEntry is a helper function to create and add a new -// entry to the cache map and evict older entry as needed. -func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry { - // Take over the next entry. - entry := &c.entries[c.next] - if c.cache[entry.addr] == entry { - delete(c.cache, entry.addr) +// getOrCreateEntryLocked retrieves a cache entry associated with k. The +// returned entry is always refreshed in the cache (it is reachable via the +// map, and its place is bumped in LRU). +// +// If a matching entry exists in the cache, it is returned. If no matching +// entry exists and the cache is full, an existing entry is evicted via LRU, +// reset to state incomplete, and returned. If no matching entry exists and the +// cache is not full, a new entry with state incomplete is allocated and +// returned. +func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry { + if entry, ok := c.cache.table[k]; ok { + c.cache.lru.Remove(entry) + c.cache.lru.PushFront(entry) + return entry } + var entry *linkAddrEntry + if len(c.cache.table) == linkAddrCacheSize { + entry = c.cache.lru.Back() - // Mark the soon-to-be-replaced entry as expired, just in case there is - // someone waiting for address resolution on it. - entry.changeState(expired) + delete(c.cache.table, entry.addr) + c.cache.lru.Remove(entry) - *entry = linkAddrEntry{ - addr: k, - linkAddr: v, - expiration: time.Now().Add(c.ageLimit), - wakers: make(map[*sleep.Waker]struct{}), - done: make(chan struct{}), + // Wake waiters and mark the soon-to-be-reused entry as expired. Note + // that the state passed doesn't matter when the zero time is passed. + entry.changeState(failed, time.Time{}) + } else { + entry = new(linkAddrEntry) } - c.cache[k] = entry - c.next = (c.next + 1) % len(c.entries) + *entry = linkAddrEntry{ + addr: k, + s: incomplete, + } + c.cache.table[k] = entry + c.cache.lru.PushFront(entry) return entry } @@ -208,43 +187,55 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo } } - c.mu.Lock() - defer c.mu.Unlock() - if entry, ok := c.cache[k]; ok { - switch s := entry.state(); s { - case expired: - case ready: - return entry.linkAddr, nil, nil - case failed: - return "", nil, tcpip.ErrNoLinkAddress - case incomplete: - // Address resolution is still in progress. - entry.maybeAddWaker(waker) - return "", entry.done, tcpip.ErrWouldBlock - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) + c.cache.Lock() + defer c.cache.Unlock() + entry := c.getOrCreateEntryLocked(k) + switch s := entry.s; s { + case ready, failed: + if !time.Now().After(entry.expiration) { + // Not expired. + switch s { + case ready: + return entry.linkAddr, nil, nil + case failed: + return entry.linkAddr, nil, tcpip.ErrNoLinkAddress + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } } - } - if linkRes == nil { - return "", nil, tcpip.ErrNoLinkAddress - } + entry.changeState(incomplete, time.Time{}) + fallthrough + case incomplete: + if waker != nil { + if entry.wakers == nil { + entry.wakers = make(map[*sleep.Waker]struct{}) + } + entry.wakers[waker] = struct{}{} + } - // Add 'incomplete' entry in the cache to mark that resolution is in progress. - e := c.makeAndAddEntry(k, "") - e.maybeAddWaker(waker) + if entry.done == nil { + // Address resolution needs to be initiated. + if linkRes == nil { + return entry.linkAddr, nil, tcpip.ErrNoLinkAddress + } - go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + entry.done = make(chan struct{}) + go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + } - return "", e.done, tcpip.ErrWouldBlock + return entry.linkAddr, entry.done, tcpip.ErrWouldBlock + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } } // removeWaker removes a waker previously added through get(). func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { - c.mu.Lock() - defer c.mu.Unlock() + c.cache.Lock() + defer c.cache.Unlock() - if entry, ok := c.cache[k]; ok { + if entry, ok := c.cache.table[k]; ok { entry.removeWaker(waker) } } @@ -256,8 +247,8 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) select { - case <-time.After(c.resolutionTimeout): - if stop := c.checkLinkRequest(k, i); stop { + case now := <-time.After(c.resolutionTimeout): + if stop := c.checkLinkRequest(now, k, i); stop { return } case <-done: @@ -269,38 +260,36 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link // checkLinkRequest checks whether previous attempt to resolve address has succeeded // and mark the entry accordingly, e.g. ready, failed, etc. Return true if request // can stop, false if another request should be sent. -func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool { - c.mu.Lock() - defer c.mu.Unlock() - - entry, ok := c.cache[k] +func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { + c.cache.Lock() + defer c.cache.Unlock() + entry, ok := c.cache.table[k] if !ok { // Entry was evicted from the cache. return true } - - switch s := entry.state(); s { - case ready, failed, expired: + switch s := entry.s; s { + case ready, failed: // Entry was made ready by resolver or failed. Either way we're done. - return true case incomplete: - if attempt+1 >= c.resolutionAttempts { - // Max number of retries reached, mark entry as failed. - entry.changeState(failed) - return true + if attempt+1 < c.resolutionAttempts { + // No response yet, need to send another ARP request. + return false } - // No response yet, need to send another ARP request. - return false + // Max number of retries reached, mark entry as failed. + entry.changeState(failed, now.Add(c.ageLimit)) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } + return true } func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { - return &linkAddrCache{ + c := &linkAddrCache{ ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, - cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize), } + c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize) + return c } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 924f4d240..9946b8fe8 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -17,6 +17,7 @@ package stack import ( "fmt" "sync" + "sync/atomic" "testing" "time" @@ -29,25 +30,34 @@ type testaddr struct { linkAddr tcpip.LinkAddress } -var testaddrs []testaddr +var testAddrs = func() []testaddr { + var addrs []testaddr + for i := 0; i < 4*linkAddrCacheSize; i++ { + addr := fmt.Sprintf("Addr%06d", i) + addrs = append(addrs, testaddr{ + addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, + linkAddr: tcpip.LinkAddress("Link" + addr), + }) + } + return addrs +}() type testLinkAddressResolver struct { - cache *linkAddrCache - delay time.Duration + cache *linkAddrCache + delay time.Duration + onLinkAddressRequest func() } func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { - go func() { - if r.delay > 0 { - time.Sleep(r.delay) - } - r.fakeRequest(addr) - }() + time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) + if f := r.onLinkAddressRequest; f != nil { + f() + } return nil } func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { - for _, ta := range testaddrs { + for _, ta := range testAddrs { if ta.addr.Addr == addr { r.cache.add(ta.addr, ta.linkAddr) break @@ -80,20 +90,10 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe } } -func init() { - for i := 0; i < 4*linkAddrCacheSize; i++ { - addr := fmt.Sprintf("Addr%06d", i) - testaddrs = append(testaddrs, testaddr{ - addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, - linkAddr: tcpip.LinkAddress("Link" + addr), - }) - } -} - func TestCacheOverflow(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - for i := len(testaddrs) - 1; i >= 0; i-- { - e := testaddrs[i] + for i := len(testAddrs) - 1; i >= 0; i-- { + e := testAddrs[i] c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { @@ -105,7 +105,7 @@ func TestCacheOverflow(t *testing.T) { } // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { - e := testaddrs[i] + e := testAddrs[i] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) @@ -115,8 +115,8 @@ func TestCacheOverflow(t *testing.T) { } } // The earliest entries should no longer be in the cache. - for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- { - e := testaddrs[i] + for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { + e := testAddrs[i] if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) } @@ -130,7 +130,7 @@ func TestCacheConcurrent(t *testing.T) { for r := 0; r < 16; r++ { wg.Add(1) go func() { - for _, e := range testaddrs { + for _, e := range testAddrs { c.add(e.addr, e.linkAddr) c.get(e.addr, nil, "", nil, nil) // make work for gotsan } @@ -142,7 +142,7 @@ func TestCacheConcurrent(t *testing.T) { // All goroutines add in the same order and add more values than // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. - e := testaddrs[len(testaddrs)-1] + e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -151,7 +151,7 @@ func TestCacheConcurrent(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) } - e = testaddrs[0] + e = testAddrs[0] if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } @@ -159,7 +159,7 @@ func TestCacheConcurrent(t *testing.T) { func TestCacheAgeLimit(t *testing.T) { c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) - e := testaddrs[0] + e := testAddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { @@ -169,7 +169,7 @@ func TestCacheAgeLimit(t *testing.T) { func TestCacheReplace(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - e := testaddrs[0] + e := testAddrs[0] l2 := e.linkAddr + "2" c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) @@ -193,7 +193,7 @@ func TestCacheReplace(t *testing.T) { func TestCacheResolution(t *testing.T) { c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) linkRes := &testLinkAddressResolver{cache: c} - for i, ta := range testaddrs { + for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) @@ -205,7 +205,7 @@ func TestCacheResolution(t *testing.T) { // Check that after resolved, address stays in the cache and never returns WouldBlock. for i := 0; i < 10; i++ { - e := testaddrs[len(testaddrs)-1] + e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -220,8 +220,13 @@ func TestCacheResolutionFailed(t *testing.T) { c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) linkRes := &testLinkAddressResolver{cache: c} + var requestCount uint32 + linkRes.onLinkAddressRequest = func() { + atomic.AddUint32(&requestCount, 1) + } + // First, sanity check that resolution is working... - e := testaddrs[0] + e := testAddrs[0] got, err := getBlocking(c, e.addr, linkRes) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -230,10 +235,16 @@ func TestCacheResolutionFailed(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) } + before := atomic.LoadUint32(&requestCount) + e.addr.Addr += "2" if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } + + if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { + t.Errorf("got link address request count = %d, want = %d", got, want) + } } func TestCacheResolutionTimeout(t *testing.T) { @@ -242,7 +253,7 @@ func TestCacheResolutionTimeout(t *testing.T) { c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} - e := testaddrs[0] + e := testAddrs[0] if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 04b63d783..89b4c5960 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -531,6 +531,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return nil } +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) { + r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) + r.RemoteLinkAddress = remotelinkAddr + ref.ep.HandlePacket(&r, vv) + ref.decRef() +} + // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the physical interface. @@ -558,6 +565,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr src, dst := netProto.ParseAddresses(vv.First()) + n.stack.AddLinkAddress(n.id, src, remote) + // If the packet is destined to the IPv4 Broadcast address, then make a // route to each IPv4 network endpoint and let each endpoint handle the // packet. @@ -566,10 +575,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n.mu.RLock() for _, ref := range n.endpoints { if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) } } n.mu.RUnlock() @@ -577,10 +583,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr } if ref := n.getRef(protocol, dst); ref != nil { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) return } |