diff options
author | Sepehr Raissian <sepehrtheraiss@gmail.com> | 2018-09-28 10:59:21 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-09-28 11:00:16 -0700 |
commit | c17ea8c6e20f58510b063f064d45608792a014e4 (patch) | |
tree | 272920bd1036a776e3e7c852c2c5795dc4f2a805 /pkg/tcpip/stack | |
parent | cf226d48ce8c49409049e03ed405366db9fc2a04 (diff) |
Block for link address resolution
Previously, if address resolution for UDP or Ping sockets required sending
packets using Write in Transport layer, Resolve would return ErrWouldBlock
and Write would return ErrNoLinkAddress. Meanwhile startAddressResolution
would run in background. Further calls to Write using same address would also
return ErrNoLinkAddress until resolution has been completed successfully.
Since Write is not allowed to block and System Calls need to be
interruptible in System Call layer, the caller to Write is responsible for
blocking upon return of ErrWouldBlock.
Now, when startAddressResolution is called a notification channel for
the completion of the address resolution is returned.
The channel will traverse up to the calling function of Write as well as
ErrNoLinkAddress. Once address resolution is complete (success or not) the
channel is closed. The caller would call Write again to send packets and
check if address resolution was compeleted successfully or not.
Fixes google/gvisor#5
Change-Id: Idafaf31982bee1915ca084da39ae7bd468cebd93
PiperOrigin-RevId: 214962200
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 12 |
6 files changed, 63 insertions, 37 deletions
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 04b8f251a..3a147a75f 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -88,12 +88,14 @@ type linkAddrEntry struct { linkAddr tcpip.LinkAddress expiration time.Time s entryState + resDone bool // wakers is a set of waiters for address resolution result. Anytime // state transitions out of 'incomplete' these waiters are notified. wakers map[*sleep.Waker]struct{} cancel chan struct{} + resCh chan struct{} } func (e *linkAddrEntry) state() entryState { @@ -182,15 +184,20 @@ func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress // someone waiting for address resolution on it. entry.changeState(expired) if entry.cancel != nil { - entry.cancel <- struct{}{} + if !entry.resDone { + close(entry.resCh) + } + close(entry.cancel) } *entry = linkAddrEntry{ addr: k, linkAddr: v, expiration: time.Now().Add(c.ageLimit), + resDone: false, wakers: make(map[*sleep.Waker]struct{}), cancel: make(chan struct{}, 1), + resCh: make(chan struct{}, 1), } c.cache[k] = entry @@ -202,10 +209,10 @@ func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { - return addr, nil + return addr, nil, nil } } @@ -214,10 +221,11 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo if entry == nil || entry.state() == expired { c.mu.Unlock() if linkRes == nil { - return "", tcpip.ErrNoLinkAddress + return "", nil, tcpip.ErrNoLinkAddress } - c.startAddressResolution(k, linkRes, localAddr, linkEP, waker) - return "", tcpip.ErrWouldBlock + + ch := c.startAddressResolution(k, linkRes, localAddr, linkEP, waker) + return "", ch, tcpip.ErrWouldBlock } defer c.mu.Unlock() @@ -227,13 +235,13 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo // in that case it's safe to consider it ready. fallthrough case ready: - return entry.linkAddr, nil + return entry.linkAddr, nil, nil case failed: - return "", tcpip.ErrNoLinkAddress + return "", nil, tcpip.ErrNoLinkAddress case incomplete: // Address resolution is still in progress. entry.addWaker(waker) - return "", tcpip.ErrWouldBlock + return "", entry.resCh, tcpip.ErrWouldBlock default: panic(fmt.Sprintf("invalid cache entry state: %d", s)) } @@ -249,13 +257,13 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { } } -func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) { +func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) <-chan struct{} { c.mu.Lock() defer c.mu.Unlock() // Look up again with lock held to ensure entry wasn't added by someone else. if e := c.cache[k]; e != nil && e.state() != expired { - return + return nil } // Add 'incomplete' entry in the cache to mark that resolution is in progress. @@ -274,6 +282,15 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link select { case <-time.After(c.resolutionTimeout): if stop := c.checkLinkRequest(k, i); stop { + // If entry is evicted then resCh is already closed. + c.mu.Lock() + if e, ok := c.cache[k]; ok { + if !e.resDone { + e.resDone = true + close(e.resCh) + } + } + c.mu.Unlock() return } case <-cancel: @@ -281,6 +298,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link } } }() + return e.resCh } // checkLinkRequest checks whether previous attempt to resolve address has succeeded diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index f0988d6de..e46267f12 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -73,7 +73,7 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe defer s.Done() for { - if got, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { + if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { return got, err } s.Fetch(true) @@ -95,7 +95,7 @@ func TestCacheOverflow(t *testing.T) { for i := len(testaddrs) - 1; i >= 0; i-- { e := testaddrs[i] c.add(e.addr, e.linkAddr) - got, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) } @@ -106,7 +106,7 @@ func TestCacheOverflow(t *testing.T) { // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { e := testaddrs[i] - got, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) } @@ -117,7 +117,7 @@ func TestCacheOverflow(t *testing.T) { // The earliest entries should no longer be in the cache. for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- { e := testaddrs[i] - if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { + if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) } } @@ -143,7 +143,7 @@ func TestCacheConcurrent(t *testing.T) { // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. e := testaddrs[len(testaddrs)-1] - got, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) } @@ -152,7 +152,7 @@ func TestCacheConcurrent(t *testing.T) { } e = testaddrs[0] - if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { + if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } } @@ -162,7 +162,7 @@ func TestCacheAgeLimit(t *testing.T) { e := testaddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { + if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } } @@ -172,7 +172,7 @@ func TestCacheReplace(t *testing.T) { e := testaddrs[0] l2 := e.linkAddr + "2" c.add(e.addr, e.linkAddr) - got, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) } @@ -181,7 +181,7 @@ func TestCacheReplace(t *testing.T) { } c.add(e.addr, l2) - got, err = c.get(e.addr, nil, "", nil, nil) + got, _, err = c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) } @@ -206,7 +206,7 @@ func TestCacheResolution(t *testing.T) { // Check that after resolved, address stays in the cache and never returns WouldBlock. for i := 0; i < 10; i++ { e := testaddrs[len(testaddrs)-1] - got, err := c.get(e.addr, linkRes, "", nil, nil) + got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) } @@ -256,7 +256,7 @@ func TestStaticResolution(t *testing.T) { addr := tcpip.Address("broadcast") want := tcpip.LinkAddress("mac_broadcast") - got, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil) + got, _, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err) } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 595c7e793..0acec2984 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -289,7 +289,11 @@ type LinkAddressCache interface { // registered with the network protocol, the cache attempts to resolve the address // and returns ErrWouldBlock. Waker is notified when address resolution is // complete (success or not). - GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) + // + // If address resolution is required, ErrNoLinkAddress and a notification channel is + // returned for the top level caller to block. Channel is closed once address resolution + // is complete (success or not). + GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) // RemoveWaker removes a waker that has been added in GetLinkAddress(). RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index cc9b24e23..6c6400c33 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -89,11 +89,15 @@ func (r *Route) Capabilities() LinkEndpointCapabilities { // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). -func (r *Route) Resolve(waker *sleep.Waker) *tcpip.Error { +// +// If address resolution is required, ErrNoLinkAddress and a notification channel is +// returned for the top level caller to block. Channel is closed once address resolution +// is complete (success or not). +func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if !r.IsResolutionRequired() { // Nothing to do if there is no cache (which does the resolution on cache miss) or // link address is already known. - return nil + return nil, nil } nextAddr := r.NextHop @@ -101,16 +105,16 @@ func (r *Route) Resolve(waker *sleep.Waker) *tcpip.Error { // Local link address is already known. if r.RemoteAddress == r.LocalAddress { r.RemoteLinkAddress = r.LocalLinkAddress - return nil + return nil, nil } nextAddr = r.RemoteAddress } - linkAddr, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) if err != nil { - return err + return ch, err } r.RemoteLinkAddress = linkAddr - return nil + return nil, nil } // RemoveWaker removes a waker that has been added in Resolve(). diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 699519be1..d1ec6a660 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -831,12 +831,12 @@ func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr t } // GetLinkAddress implements LinkAddressCache.GetLinkAddress. -func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() nic := s.nics[nicid] if nic == nil { s.mu.RUnlock() - return "", tcpip.ErrUnknownNICID + return "", nil, tcpip.ErrUnknownNICID } s.mu.RUnlock() diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 9ec37e7b6..98cc3b120 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -60,21 +60,21 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { - return 0, tcpip.ErrNoRoute + return 0, nil, tcpip.ErrNoRoute } hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) v, err := p.Get(p.Size()) if err != nil { - return 0, err + return 0, nil, err } if err := f.route.WritePacket(hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil { - return 0, err + return 0, nil, err } - return uintptr(len(v)), nil + return uintptr(len(v)), nil, nil } func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { @@ -362,7 +362,7 @@ func TestTransportSend(t *testing.T) { // Create buffer that will hold the payload. view := buffer.NewView(30) - _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) + _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) if err != nil { t.Fatalf("write failed: %v", err) } |