summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-01-21 16:37:35 -0800
committergVisor bot <gvisor-bot@google.com>2021-01-21 16:40:06 -0800
commit8ecff1890277820972c5f5287539a824b22a1d60 (patch)
tree5c3b3473dfa825feffee79ae1b47eda1a9e22c2b /pkg/tcpip/stack
parent48dfb8db9e784604e9c7ad8e1a36cc862dac1b4d (diff)
Do not cache remote link address in Route
...unless explicitly requested via ResolveWith. Remove cancelled channels from pending packets as we can use the link resolution channel in a FIFO to limit the number of maximum pending resolutions we should queue packets for. This change also defers starting the goroutine that handles link resolution completion to when link resolution succeeds, fails or gets cancelled due to the max number of pending resolutions being reached. Fixes #751. PiperOrigin-RevId: 353130577
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go20
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go20
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go8
-rw-r--r--pkg/tcpip/stack/nic.go47
-rw-r--r--pkg/tcpip/stack/pending_packets.go247
-rw-r--r--pkg/tcpip/stack/route.go67
6 files changed, 236 insertions, 173 deletions
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 3c4fa341e..f116f8417 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -32,6 +32,8 @@ var _ LinkAddressCache = (*linkAddrCache)(nil)
//
// This struct is safe for concurrent use.
type linkAddrCache struct {
+ nic *NIC
+
// ageLimit is how long a cache entry is valid for.
ageLimit time.Duration
@@ -79,6 +81,8 @@ type linkAddrEntry struct {
// linkAddrEntryEntry access is synchronized by the linkAddrCache lock.
linkAddrEntryEntry
+ cache *linkAddrCache
+
// TODO(gvisor.dev/issue/5150): move these fields under mu.
// mu protects the fields below.
mu sync.RWMutex
@@ -104,6 +108,14 @@ func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) {
if ch := e.done; ch != nil {
close(ch)
e.done = nil
+ // Dequeue the pending packets in a new goroutine to not hold up the current
+ // goroutine as writing packets may be a costly operation.
+ //
+ // At the time of writing, when writing packets, a neighbor's link address
+ // is resolved (which ends up obtaining the entry's lock) while holding the
+ // link resolution queue's lock. Dequeuing packets in a new goroutine avoids
+ // a lock ordering violation.
+ go e.cache.nic.linkResQueue.dequeue(ch, linkAddr, len(linkAddr) != 0)
}
}
@@ -174,8 +186,9 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry {
}
*entry = linkAddrEntry{
- addr: k,
- s: incomplete,
+ cache: c,
+ addr: k,
+ s: incomplete,
}
c.cache.table[k] = entry
c.cache.lru.PushFront(entry)
@@ -264,8 +277,9 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt
return true
}
-func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
+func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
c := &linkAddrCache{
+ nic: nic,
ageLimit: ageLimit,
resolutionTimeout: resolutionTimeout,
resolutionAttempts: resolutionAttempts,
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 8c35067c6..88fbbf3fe 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -93,8 +93,14 @@ func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolv
}
}
+func newEmptyNIC() *NIC {
+ n := &NIC{}
+ n.linkResQueue.init(n)
+ return n
+}
+
func TestCacheOverflow(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
for i := len(testAddrs) - 1; i >= 0; i-- {
e := testAddrs[i]
c.AddLinkAddress(e.addr, e.linkAddr)
@@ -129,7 +135,7 @@ func TestCacheOverflow(t *testing.T) {
}
func TestCacheConcurrent(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
linkRes := &testLinkAddressResolver{cache: c}
var wg sync.WaitGroup
@@ -165,7 +171,7 @@ func TestCacheConcurrent(t *testing.T) {
}
func TestCacheAgeLimit(t *testing.T) {
- c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3)
linkRes := &testLinkAddressResolver{cache: c}
e := testAddrs[0]
@@ -177,7 +183,7 @@ func TestCacheAgeLimit(t *testing.T) {
}
func TestCacheReplace(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
e := testAddrs[0]
l2 := e.linkAddr + "2"
c.AddLinkAddress(e.addr, e.linkAddr)
@@ -206,7 +212,7 @@ func TestCacheResolution(t *testing.T) {
//
// Using a large resolution timeout decreases the probability of experiencing
// this race condition and does not affect how long this test takes to run.
- c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1)
linkRes := &testLinkAddressResolver{cache: c}
for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
@@ -232,7 +238,7 @@ func TestCacheResolution(t *testing.T) {
}
func TestCacheResolutionFailed(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5)
linkRes := &testLinkAddressResolver{cache: c}
var requestCount uint32
@@ -265,7 +271,7 @@ func TestCacheResolutionFailed(t *testing.T) {
func TestCacheResolutionTimeout(t *testing.T) {
resolverDelay := 500 * time.Millisecond
expiration := resolverDelay / 10
- c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
+ c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3)
linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
e := testAddrs[0]
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 75afb3001..697132689 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -150,6 +150,14 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
if ch := e.done; ch != nil {
close(ch)
e.done = nil
+ // Dequeue the pending packets in a new goroutine to not hold up the current
+ // goroutine as writing packets may be a costly operation.
+ //
+ // At the time of writing, when writing packets, a neighbor's link address
+ // is resolved (which ends up obtaining the entry's lock) while holding the
+ // link resolution queue's lock. Dequeuing packets in a new goroutine avoids
+ // a lock ordering violation.
+ go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded)
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 447b5c99d..7592cff75 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -139,9 +139,9 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
context: ctx,
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
- linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
}
- nic.linkResQueue.init()
+ nic.linkResQueue.init(nic)
+ nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts)
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
// Check for Neighbor Unreachability Detection support.
@@ -303,6 +303,10 @@ func (n *NIC) IsLoopback() bool {
// WritePacket implements NetworkLinkEndpoint.
func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt)
+ return err
+}
+func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) {
// As per relevant RFCs, we should queue packets while we wait for link
// resolution to complete.
//
@@ -320,16 +324,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// be limited to some small value. When a queue overflows, the new arrival
// SHOULD replace the oldest entry. Once address resolution completes, the
// node transmits any queued packets.
- if ch, err := r.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- r.Acquire()
- n.linkResQueue.enqueue(ch, r, protocol, pkt)
- return nil
- }
- return err
- }
-
- return n.writePacket(r.Fields(), gso, protocol, pkt)
+ return n.linkResQueue.enqueue(r, gso, protocol, pkt)
}
// WritePacketToRemote implements NetworkInterface.
@@ -358,33 +353,7 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN
// WritePackets implements NetworkLinkEndpoint.
func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- // As per relevant RFCs, we should queue packets while we wait for link
- // resolution to complete.
- //
- // RFC 1122 section 2.3.2.2 (for IPv4):
- // The link layer SHOULD save (rather than discard) at least
- // one (the latest) packet of each set of packets destined to
- // the same unresolved IP address, and transmit the saved
- // packet when the address has been resolved.
- //
- // RFC 4861 section 7.2.2 (for IPv6):
- // While waiting for address resolution to complete, the sender MUST, for
- // each neighbor, retain a small queue of packets waiting for address
- // resolution to complete. The queue MUST hold at least one packet, and MAY
- // contain more. However, the number of queued packets per neighbor SHOULD
- // be limited to some small value. When a queue overflows, the new arrival
- // SHOULD replace the oldest entry. Once address resolution completes, the
- // node transmits any queued packets.
- if ch, err := r.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- r.Acquire()
- n.linkResQueue.enqueue(ch, r, protocol, &pkts)
- return pkts.Len(), nil
- }
- return 0, err
- }
-
- return n.writePackets(r.Fields(), gso, protocol, pkts)
+ return n.enqueuePacketBuffer(r, gso, protocol, &pkts)
}
func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, *tcpip.Error) {
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index 3ac039c7d..22dfc7960 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -45,135 +45,202 @@ func (p *PacketBufferList) len() int {
}
type pendingPacket struct {
- route *Route
- proto tcpip.NetworkProtocolNumber
- pkt pendingPacketBuffer
+ routeInfo RouteInfo
+ gso *GSO
+ proto tcpip.NetworkProtocolNumber
+ pkt pendingPacketBuffer
}
// packetsPendingLinkResolution is a queue of packets pending link resolution.
//
// Once link resolution completes successfully, the packets will be written.
type packetsPendingLinkResolution struct {
- sync.Mutex
+ nic *NIC
- // The packets to send once the resolver completes.
- packets map[<-chan struct{}][]pendingPacket
+ mu struct {
+ sync.Mutex
- // FIFO of channels used to cancel the oldest goroutine waiting for
- // link-address resolution.
- cancelChans []chan struct{}
-}
+ // The packets to send once the resolver completes.
+ //
+ // The link resolution channel is used as the key for this map.
+ packets map[<-chan struct{}][]pendingPacket
-func (f *packetsPendingLinkResolution) init() {
- f.Lock()
- defer f.Unlock()
- f.packets = make(map[<-chan struct{}][]pendingPacket)
+ // FIFO of channels used to cancel the oldest goroutine waiting for
+ // link-address resolution.
+ //
+ // cancelChans holds the same channels that are used as keys to packets.
+ cancelChans []<-chan struct{}
+ }
}
-func incrementOutgoingPacketErrors(r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) {
+func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) {
n := uint64(pkt.len())
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(n)
+ f.nic.stack.stats.IP.OutgoingPacketErrors.IncrementBy(n)
- // ok may be false if the endpoint's stats do not collect IP-related data.
- if ipEndpointStats, ok := r.outgoingNIC.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok {
+ if ipEndpointStats, ok := f.nic.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok {
ipEndpointStats.IPStats().OutgoingPacketErrors.IncrementBy(n)
}
}
-func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) {
- f.Lock()
- defer f.Unlock()
+func (f *packetsPendingLinkResolution) init(nic *NIC) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.nic = nic
+ f.mu.packets = make(map[<-chan struct{}][]pendingPacket)
+}
- packets, ok := f.packets[ch]
- if len(packets) == maxPendingPacketsPerResolution {
- p := packets[0]
- packets[0] = pendingPacket{}
- packets = packets[1:]
+// dequeue any pending packets associated with ch.
+//
+// If success is true, packets will be written and sent to the given remote link
+// address.
+func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpip.LinkAddress, success bool) {
+ f.mu.Lock()
+ packets, ok := f.mu.packets[ch]
+ delete(f.mu.packets, ch)
+
+ if ok {
+ for i, cancelChan := range f.mu.cancelChans {
+ if cancelChan == ch {
+ f.mu.cancelChans = append(f.mu.cancelChans[:i], f.mu.cancelChans[i+1:]...)
+ break
+ }
+ }
+ }
+
+ f.mu.Unlock()
- incrementOutgoingPacketErrors(r, proto, p.pkt)
+ if ok {
+ f.dequeuePackets(packets, linkAddr, success)
+ }
+}
- p.route.Release()
+func (f *packetsPendingLinkResolution) writePacketBuffer(r RouteInfo, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) {
+ switch pkt := pkt.(type) {
+ case *PacketBuffer:
+ if err := f.nic.writePacket(r, gso, proto, pkt); err != nil {
+ return 0, err
+ }
+ return 1, nil
+ case *PacketBufferList:
+ return f.nic.writePackets(r, gso, proto, *pkt)
+ default:
+ panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt))
}
+}
- if l := len(packets); l >= maxPendingPacketsPerResolution {
- panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
+// enqueue a packet to be sent once link resolution completes.
+//
+// If the maximum number of pending resolutions is reached, the packets
+// associated with the oldest link resolution will be dequeued as if they failed
+// link resolution.
+func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) {
+ f.mu.Lock()
+ // Make sure we attempt resolution while holding f's lock so that we avoid
+ // a race where link resolution completes before we enqueue the packets.
+ //
+ // A @ T1: Call ResolvedFields (get link resolution channel)
+ // B @ T2: Complete link resolution, dequeue pending packets
+ // C @ T1: Enqueue packet that already completed link resolution (which will
+ // never dequeue)
+ //
+ // To make sure B does not interleave with A and C, we make sure A and C are
+ // done while holding the lock.
+ routeInfo, ch, err := r.ResolvedFields(nil)
+ switch err {
+ case nil:
+ // The route resolved immediately, so we don't need to wait for link
+ // resolution to send the packet.
+ f.mu.Unlock()
+ return f.writePacketBuffer(routeInfo, gso, proto, pkt)
+ case tcpip.ErrWouldBlock:
+ // We need to wait for link resolution to complete.
+ default:
+ f.mu.Unlock()
+ return 0, err
}
- f.packets[ch] = append(packets, pendingPacket{
- route: r,
- proto: proto,
- pkt: pkt,
+ defer f.mu.Unlock()
+
+ packets, ok := f.mu.packets[ch]
+ packets = append(packets, pendingPacket{
+ routeInfo: routeInfo,
+ gso: gso,
+ proto: proto,
+ pkt: pkt,
})
+ if len(packets) > maxPendingPacketsPerResolution {
+ f.incrementOutgoingPacketErrors(packets[0].proto, packets[0].pkt)
+ packets[0] = pendingPacket{}
+ packets = packets[1:]
+
+ if numPackets := len(packets); numPackets != maxPendingPacketsPerResolution {
+ panic(fmt.Sprintf("holding more queued packets than expected; got = %d, want <= %d", numPackets, maxPendingPacketsPerResolution))
+ }
+ }
+
+ f.mu.packets[ch] = packets
+
if ok {
- return
+ return pkt.len(), nil
}
- // Wait for the link-address resolution to complete.
- cancel := f.newCancelChannelLocked()
- go func() {
- cancelled := false
- select {
- case <-ch:
- case <-cancel:
- cancelled = true
- }
+ cancelledPackets := f.newCancelChannelLocked(ch)
- f.Lock()
- packets, ok := f.packets[ch]
- delete(f.packets, ch)
- f.Unlock()
+ if len(cancelledPackets) != 0 {
+ // Dequeue the pending packets in a new goroutine to not hold up the current
+ // goroutine as handing link resolution failures may be a costly operation.
+ go f.dequeuePackets(cancelledPackets, "" /* linkAddr */, false /* success */)
+ }
- if !ok {
- panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets"))
- }
+ return pkt.len(), nil
+}
- for _, p := range packets {
- if cancelled || p.route.IsResolutionRequired() {
- incrementOutgoingPacketErrors(r, proto, p.pkt)
-
- if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok {
- switch pkt := p.pkt.(type) {
- case *PacketBuffer:
- linkResolvableEP.HandleLinkResolutionFailure(pkt)
- case *PacketBufferList:
- for pb := pkt.Front(); pb != nil; pb = pb.Next() {
- linkResolvableEP.HandleLinkResolutionFailure(pb)
- }
- default:
- panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt))
- }
- }
- } else {
+// newCancelChannelLocked appends the link resolution channel to a FIFO. If the
+// maximum number of pending resolutions is reached, the oldest channel will be
+// removed and its associated pending packets will be returned.
+func (f *packetsPendingLinkResolution) newCancelChannelLocked(newCH <-chan struct{}) []pendingPacket {
+ f.mu.cancelChans = append(f.mu.cancelChans, newCH)
+ if len(f.mu.cancelChans) <= maxPendingResolutions {
+ return nil
+ }
+
+ ch := f.mu.cancelChans[0]
+ f.mu.cancelChans[0] = nil
+ f.mu.cancelChans = f.mu.cancelChans[1:]
+ if l := len(f.mu.cancelChans); l > maxPendingResolutions {
+ panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
+ }
+
+ packets, ok := f.mu.packets[ch]
+ if !ok {
+ panic("must have a packet queue for an uncancelled channel")
+ }
+ delete(f.mu.packets, ch)
+
+ return packets
+}
+
+func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, linkAddr tcpip.LinkAddress, success bool) {
+ for _, p := range packets {
+ if success {
+ p.routeInfo.RemoteLinkAddress = linkAddr
+ _, _ = f.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt)
+ } else {
+ f.incrementOutgoingPacketErrors(p.proto, p.pkt)
+
+ if linkResolvableEP, ok := f.nic.getNetworkEndpoint(p.proto).(LinkResolvableNetworkEndpoint); ok {
switch pkt := p.pkt.(type) {
case *PacketBuffer:
- p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, pkt)
+ linkResolvableEP.HandleLinkResolutionFailure(pkt)
case *PacketBufferList:
- p.route.outgoingNIC.writePackets(p.route.Fields(), nil /* gso */, p.proto, *pkt)
+ for pb := pkt.Front(); pb != nil; pb = pb.Next() {
+ linkResolvableEP.HandleLinkResolutionFailure(pb)
+ }
default:
panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt))
}
}
- p.route.Release()
}
- }()
-}
-
-// newCancelChannel creates a channel that can cancel a pending forwarding
-// activity. The oldest channel is closed if the number of open channels would
-// exceed maxPendingResolutions.
-func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} {
- if len(f.cancelChans) == maxPendingResolutions {
- ch := f.cancelChans[0]
- f.cancelChans[0] = nil
- f.cancelChans = f.cancelChans[1:]
- close(ch)
}
- if l := len(f.cancelChans); l >= maxPendingResolutions {
- panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
- }
-
- ch := make(chan struct{})
- f.cancelChans = append(f.cancelChans, ch)
- return ch
}
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 1ff7b3a37..093b676aa 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -86,12 +86,21 @@ type RouteInfo struct {
RemoteLinkAddress tcpip.LinkAddress
}
-// Fields returns a RouteInfo with all of r's exported fields. This allows
-// callers to store the route's fields without retaining a reference to it.
+// Fields returns a RouteInfo with all of the known values for the route's
+// fields.
+//
+// If any fields are unknown (e.g. remote link address when it is waiting for
+// link address resolution), they will be unset.
func (r *Route) Fields() RouteInfo {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.fieldsLocked()
+}
+
+func (r *Route) fieldsLocked() RouteInfo {
return RouteInfo{
routeInfo: r.routeInfo,
- RemoteLinkAddress: r.RemoteLinkAddress(),
+ RemoteLinkAddress: r.mu.remoteLinkAddress,
}
}
@@ -306,29 +315,26 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
r.mu.remoteLinkAddress = addr
}
-// Resolve attempts to resolve the link address if necessary.
+// ResolvedFields is like Fields but also attempts to resolve the remote link
+// address if it is not yet known.
//
-// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g.
-// waiting for ARP reply). If address resolution is required, a notification
-// channel is also returned for the caller to block on. The channel is closed
+// If address resolution is required, returns tcpip.ErrWouldBlock and a
+// notification channel for the caller to block on. The channel will be readable
// once address resolution is complete (successful or not). If a callback is
// provided, it will be called when address resolution is complete, regardless
-// of success or failure.
-func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) {
- r.mu.Lock()
-
- if !r.isResolutionRequiredRLocked() {
- // Nothing to do if there is no cache (which does the resolution on cache miss) or
- // link address is already known.
- r.mu.Unlock()
- return nil, nil
+// of success or failure before the notification channel is readable.
+//
+// Note, the route will not cache the remote link address when address
+// resolution completes.
+func (r *Route) ResolvedFields(afterResolve func()) (RouteInfo, <-chan struct{}, *tcpip.Error) {
+ r.mu.RLock()
+ fields := r.fieldsLocked()
+ resolutionRequired := r.isResolutionRequiredRLocked()
+ r.mu.RUnlock()
+ if !resolutionRequired {
+ return fields, nil, nil
}
- // Increment the route's reference count because finishResolution retains a
- // reference to the route and releases it when called.
- r.acquireLocked()
- r.mu.Unlock()
-
nextAddr := r.NextHop
if nextAddr == "" {
nextAddr = r.RemoteAddress
@@ -341,18 +347,15 @@ func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) {
linkAddressResolutionRequestLocalAddr = r.LocalAddress
}
- finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) {
- if ok {
- r.ResolveWith(linkAddress)
- }
+ linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, func(tcpip.LinkAddress, bool) {
if afterResolve != nil {
afterResolve()
}
- r.Release()
+ })
+ if err == nil {
+ fields.RemoteLinkAddress = linkAddr
}
-
- _, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution)
- return ch, err
+ return fields, ch, err
}
// local returns true if the route is a local route.
@@ -371,11 +374,7 @@ func (r *Route) IsResolutionRequired() bool {
}
func (r *Route) isResolutionRequiredRLocked() bool {
- if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() {
- return false
- }
-
- return r.linkRes != nil
+ return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local()
}
func (r *Route) isValidForOutgoing() bool {