diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 12 |
6 files changed, 63 insertions, 37 deletions
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 04b8f251a..3a147a75f 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -88,12 +88,14 @@ type linkAddrEntry struct { linkAddr tcpip.LinkAddress expiration time.Time s entryState + resDone bool // wakers is a set of waiters for address resolution result. Anytime // state transitions out of 'incomplete' these waiters are notified. wakers map[*sleep.Waker]struct{} cancel chan struct{} + resCh chan struct{} } func (e *linkAddrEntry) state() entryState { @@ -182,15 +184,20 @@ func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress // someone waiting for address resolution on it. entry.changeState(expired) if entry.cancel != nil { - entry.cancel <- struct{}{} + if !entry.resDone { + close(entry.resCh) + } + close(entry.cancel) } *entry = linkAddrEntry{ addr: k, linkAddr: v, expiration: time.Now().Add(c.ageLimit), + resDone: false, wakers: make(map[*sleep.Waker]struct{}), cancel: make(chan struct{}, 1), + resCh: make(chan struct{}, 1), } c.cache[k] = entry @@ -202,10 +209,10 @@ func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { - return addr, nil + return addr, nil, nil } } @@ -214,10 +221,11 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo if entry == nil || entry.state() == expired { c.mu.Unlock() if linkRes == nil { - return "", tcpip.ErrNoLinkAddress + return "", nil, tcpip.ErrNoLinkAddress } - c.startAddressResolution(k, linkRes, localAddr, linkEP, waker) - return "", tcpip.ErrWouldBlock + + ch := c.startAddressResolution(k, linkRes, localAddr, linkEP, waker) + return "", ch, tcpip.ErrWouldBlock } defer c.mu.Unlock() @@ -227,13 +235,13 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo // in that case it's safe to consider it ready. fallthrough case ready: - return entry.linkAddr, nil + return entry.linkAddr, nil, nil case failed: - return "", tcpip.ErrNoLinkAddress + return "", nil, tcpip.ErrNoLinkAddress case incomplete: // Address resolution is still in progress. entry.addWaker(waker) - return "", tcpip.ErrWouldBlock + return "", entry.resCh, tcpip.ErrWouldBlock default: panic(fmt.Sprintf("invalid cache entry state: %d", s)) } @@ -249,13 +257,13 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { } } -func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) { +func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) <-chan struct{} { c.mu.Lock() defer c.mu.Unlock() // Look up again with lock held to ensure entry wasn't added by someone else. if e := c.cache[k]; e != nil && e.state() != expired { - return + return nil } // Add 'incomplete' entry in the cache to mark that resolution is in progress. @@ -274,6 +282,15 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link select { case <-time.After(c.resolutionTimeout): if stop := c.checkLinkRequest(k, i); stop { + // If entry is evicted then resCh is already closed. + c.mu.Lock() + if e, ok := c.cache[k]; ok { + if !e.resDone { + e.resDone = true + close(e.resCh) + } + } + c.mu.Unlock() return } case <-cancel: @@ -281,6 +298,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link } } }() + return e.resCh } // checkLinkRequest checks whether previous attempt to resolve address has succeeded diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index f0988d6de..e46267f12 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -73,7 +73,7 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe defer s.Done() for { - if got, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { + if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { return got, err } s.Fetch(true) @@ -95,7 +95,7 @@ func TestCacheOverflow(t *testing.T) { for i := len(testaddrs) - 1; i >= 0; i-- { e := testaddrs[i] c.add(e.addr, e.linkAddr) - got, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) } @@ -106,7 +106,7 @@ func TestCacheOverflow(t *testing.T) { // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { e := testaddrs[i] - got, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) } @@ -117,7 +117,7 @@ func TestCacheOverflow(t *testing.T) { // The earliest entries should no longer be in the cache. for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- { e := testaddrs[i] - if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { + if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) } } @@ -143,7 +143,7 @@ func TestCacheConcurrent(t *testing.T) { // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. e := testaddrs[len(testaddrs)-1] - got, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) } @@ -152,7 +152,7 @@ func TestCacheConcurrent(t *testing.T) { } e = testaddrs[0] - if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { + if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } } @@ -162,7 +162,7 @@ func TestCacheAgeLimit(t *testing.T) { e := testaddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { + if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } } @@ -172,7 +172,7 @@ func TestCacheReplace(t *testing.T) { e := testaddrs[0] l2 := e.linkAddr + "2" c.add(e.addr, e.linkAddr) - got, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) } @@ -181,7 +181,7 @@ func TestCacheReplace(t *testing.T) { } c.add(e.addr, l2) - got, err = c.get(e.addr, nil, "", nil, nil) + got, _, err = c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) } @@ -206,7 +206,7 @@ func TestCacheResolution(t *testing.T) { // Check that after resolved, address stays in the cache and never returns WouldBlock. for i := 0; i < 10; i++ { e := testaddrs[len(testaddrs)-1] - got, err := c.get(e.addr, linkRes, "", nil, nil) + got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) } @@ -256,7 +256,7 @@ func TestStaticResolution(t *testing.T) { addr := tcpip.Address("broadcast") want := tcpip.LinkAddress("mac_broadcast") - got, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil) + got, _, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err) } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 595c7e793..0acec2984 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -289,7 +289,11 @@ type LinkAddressCache interface { // registered with the network protocol, the cache attempts to resolve the address // and returns ErrWouldBlock. Waker is notified when address resolution is // complete (success or not). - GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) + // + // If address resolution is required, ErrNoLinkAddress and a notification channel is + // returned for the top level caller to block. Channel is closed once address resolution + // is complete (success or not). + GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) // RemoveWaker removes a waker that has been added in GetLinkAddress(). RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index cc9b24e23..6c6400c33 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -89,11 +89,15 @@ func (r *Route) Capabilities() LinkEndpointCapabilities { // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). -func (r *Route) Resolve(waker *sleep.Waker) *tcpip.Error { +// +// If address resolution is required, ErrNoLinkAddress and a notification channel is +// returned for the top level caller to block. Channel is closed once address resolution +// is complete (success or not). +func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if !r.IsResolutionRequired() { // Nothing to do if there is no cache (which does the resolution on cache miss) or // link address is already known. - return nil + return nil, nil } nextAddr := r.NextHop @@ -101,16 +105,16 @@ func (r *Route) Resolve(waker *sleep.Waker) *tcpip.Error { // Local link address is already known. if r.RemoteAddress == r.LocalAddress { r.RemoteLinkAddress = r.LocalLinkAddress - return nil + return nil, nil } nextAddr = r.RemoteAddress } - linkAddr, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) if err != nil { - return err + return ch, err } r.RemoteLinkAddress = linkAddr - return nil + return nil, nil } // RemoveWaker removes a waker that has been added in Resolve(). diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 699519be1..d1ec6a660 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -831,12 +831,12 @@ func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr t } // GetLinkAddress implements LinkAddressCache.GetLinkAddress. -func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() nic := s.nics[nicid] if nic == nil { s.mu.RUnlock() - return "", tcpip.ErrUnknownNICID + return "", nil, tcpip.ErrUnknownNICID } s.mu.RUnlock() diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 9ec37e7b6..98cc3b120 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -60,21 +60,21 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { - return 0, tcpip.ErrNoRoute + return 0, nil, tcpip.ErrNoRoute } hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) v, err := p.Get(p.Size()) if err != nil { - return 0, err + return 0, nil, err } if err := f.route.WritePacket(hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil { - return 0, err + return 0, nil, err } - return uintptr(len(v)), nil + return uintptr(len(v)), nil, nil } func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { @@ -362,7 +362,7 @@ func TestTransportSend(t *testing.T) { // Create buffer that will hold the payload. view := buffer.NewView(30) - _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) + _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) if err != nil { t.Fatalf("write failed: %v", err) } |