diff options
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 47 | ||||
-rw-r--r-- | pkg/tcpip/stack/pending_packets.go | 247 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 67 |
5 files changed, 223 insertions, 166 deletions
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 3c4fa341e..f116f8417 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -32,6 +32,8 @@ var _ LinkAddressCache = (*linkAddrCache)(nil) // // This struct is safe for concurrent use. type linkAddrCache struct { + nic *NIC + // ageLimit is how long a cache entry is valid for. ageLimit time.Duration @@ -79,6 +81,8 @@ type linkAddrEntry struct { // linkAddrEntryEntry access is synchronized by the linkAddrCache lock. linkAddrEntryEntry + cache *linkAddrCache + // TODO(gvisor.dev/issue/5150): move these fields under mu. // mu protects the fields below. mu sync.RWMutex @@ -104,6 +108,14 @@ func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { if ch := e.done; ch != nil { close(ch) e.done = nil + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as writing packets may be a costly operation. + // + // At the time of writing, when writing packets, a neighbor's link address + // is resolved (which ends up obtaining the entry's lock) while holding the + // link resolution queue's lock. Dequeuing packets in a new goroutine avoids + // a lock ordering violation. + go e.cache.nic.linkResQueue.dequeue(ch, linkAddr, len(linkAddr) != 0) } } @@ -174,8 +186,9 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { } *entry = linkAddrEntry{ - addr: k, - s: incomplete, + cache: c, + addr: k, + s: incomplete, } c.cache.table[k] = entry c.cache.lru.PushFront(entry) @@ -264,8 +277,9 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt return true } -func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { +func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { c := &linkAddrCache{ + nic: nic, ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 75afb3001..697132689 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -150,6 +150,14 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { if ch := e.done; ch != nil { close(ch) e.done = nil + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as writing packets may be a costly operation. + // + // At the time of writing, when writing packets, a neighbor's link address + // is resolved (which ends up obtaining the entry's lock) while holding the + // link resolution queue's lock. Dequeuing packets in a new goroutine avoids + // a lock ordering violation. + go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 447b5c99d..7592cff75 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -139,9 +139,9 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC context: ctx, stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), - linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), } - nic.linkResQueue.init() + nic.linkResQueue.init(nic) + nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) // Check for Neighbor Unreachability Detection support. @@ -303,6 +303,10 @@ func (n *NIC) IsLoopback() bool { // WritePacket implements NetworkLinkEndpoint. func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt) + return err +} +func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) { // As per relevant RFCs, we should queue packets while we wait for link // resolution to complete. // @@ -320,16 +324,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // be limited to some small value. When a queue overflows, the new arrival // SHOULD replace the oldest entry. Once address resolution completes, the // node transmits any queued packets. - if ch, err := r.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - r.Acquire() - n.linkResQueue.enqueue(ch, r, protocol, pkt) - return nil - } - return err - } - - return n.writePacket(r.Fields(), gso, protocol, pkt) + return n.linkResQueue.enqueue(r, gso, protocol, pkt) } // WritePacketToRemote implements NetworkInterface. @@ -358,33 +353,7 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN // WritePackets implements NetworkLinkEndpoint. func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // As per relevant RFCs, we should queue packets while we wait for link - // resolution to complete. - // - // RFC 1122 section 2.3.2.2 (for IPv4): - // The link layer SHOULD save (rather than discard) at least - // one (the latest) packet of each set of packets destined to - // the same unresolved IP address, and transmit the saved - // packet when the address has been resolved. - // - // RFC 4861 section 7.2.2 (for IPv6): - // While waiting for address resolution to complete, the sender MUST, for - // each neighbor, retain a small queue of packets waiting for address - // resolution to complete. The queue MUST hold at least one packet, and MAY - // contain more. However, the number of queued packets per neighbor SHOULD - // be limited to some small value. When a queue overflows, the new arrival - // SHOULD replace the oldest entry. Once address resolution completes, the - // node transmits any queued packets. - if ch, err := r.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - r.Acquire() - n.linkResQueue.enqueue(ch, r, protocol, &pkts) - return pkts.Len(), nil - } - return 0, err - } - - return n.writePackets(r.Fields(), gso, protocol, pkts) + return n.enqueuePacketBuffer(r, gso, protocol, &pkts) } func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, *tcpip.Error) { diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 3ac039c7d..22dfc7960 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -45,135 +45,202 @@ func (p *PacketBufferList) len() int { } type pendingPacket struct { - route *Route - proto tcpip.NetworkProtocolNumber - pkt pendingPacketBuffer + routeInfo RouteInfo + gso *GSO + proto tcpip.NetworkProtocolNumber + pkt pendingPacketBuffer } // packetsPendingLinkResolution is a queue of packets pending link resolution. // // Once link resolution completes successfully, the packets will be written. type packetsPendingLinkResolution struct { - sync.Mutex + nic *NIC - // The packets to send once the resolver completes. - packets map[<-chan struct{}][]pendingPacket + mu struct { + sync.Mutex - // FIFO of channels used to cancel the oldest goroutine waiting for - // link-address resolution. - cancelChans []chan struct{} -} + // The packets to send once the resolver completes. + // + // The link resolution channel is used as the key for this map. + packets map[<-chan struct{}][]pendingPacket -func (f *packetsPendingLinkResolution) init() { - f.Lock() - defer f.Unlock() - f.packets = make(map[<-chan struct{}][]pendingPacket) + // FIFO of channels used to cancel the oldest goroutine waiting for + // link-address resolution. + // + // cancelChans holds the same channels that are used as keys to packets. + cancelChans []<-chan struct{} + } } -func incrementOutgoingPacketErrors(r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) { +func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) { n := uint64(pkt.len()) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(n) + f.nic.stack.stats.IP.OutgoingPacketErrors.IncrementBy(n) - // ok may be false if the endpoint's stats do not collect IP-related data. - if ipEndpointStats, ok := r.outgoingNIC.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok { + if ipEndpointStats, ok := f.nic.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok { ipEndpointStats.IPStats().OutgoingPacketErrors.IncrementBy(n) } } -func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) { - f.Lock() - defer f.Unlock() +func (f *packetsPendingLinkResolution) init(nic *NIC) { + f.mu.Lock() + defer f.mu.Unlock() + f.nic = nic + f.mu.packets = make(map[<-chan struct{}][]pendingPacket) +} - packets, ok := f.packets[ch] - if len(packets) == maxPendingPacketsPerResolution { - p := packets[0] - packets[0] = pendingPacket{} - packets = packets[1:] +// dequeue any pending packets associated with ch. +// +// If success is true, packets will be written and sent to the given remote link +// address. +func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpip.LinkAddress, success bool) { + f.mu.Lock() + packets, ok := f.mu.packets[ch] + delete(f.mu.packets, ch) + + if ok { + for i, cancelChan := range f.mu.cancelChans { + if cancelChan == ch { + f.mu.cancelChans = append(f.mu.cancelChans[:i], f.mu.cancelChans[i+1:]...) + break + } + } + } + + f.mu.Unlock() - incrementOutgoingPacketErrors(r, proto, p.pkt) + if ok { + f.dequeuePackets(packets, linkAddr, success) + } +} - p.route.Release() +func (f *packetsPendingLinkResolution) writePacketBuffer(r RouteInfo, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) { + switch pkt := pkt.(type) { + case *PacketBuffer: + if err := f.nic.writePacket(r, gso, proto, pkt); err != nil { + return 0, err + } + return 1, nil + case *PacketBufferList: + return f.nic.writePackets(r, gso, proto, *pkt) + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt)) } +} - if l := len(packets); l >= maxPendingPacketsPerResolution { - panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) +// enqueue a packet to be sent once link resolution completes. +// +// If the maximum number of pending resolutions is reached, the packets +// associated with the oldest link resolution will be dequeued as if they failed +// link resolution. +func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) { + f.mu.Lock() + // Make sure we attempt resolution while holding f's lock so that we avoid + // a race where link resolution completes before we enqueue the packets. + // + // A @ T1: Call ResolvedFields (get link resolution channel) + // B @ T2: Complete link resolution, dequeue pending packets + // C @ T1: Enqueue packet that already completed link resolution (which will + // never dequeue) + // + // To make sure B does not interleave with A and C, we make sure A and C are + // done while holding the lock. + routeInfo, ch, err := r.ResolvedFields(nil) + switch err { + case nil: + // The route resolved immediately, so we don't need to wait for link + // resolution to send the packet. + f.mu.Unlock() + return f.writePacketBuffer(routeInfo, gso, proto, pkt) + case tcpip.ErrWouldBlock: + // We need to wait for link resolution to complete. + default: + f.mu.Unlock() + return 0, err } - f.packets[ch] = append(packets, pendingPacket{ - route: r, - proto: proto, - pkt: pkt, + defer f.mu.Unlock() + + packets, ok := f.mu.packets[ch] + packets = append(packets, pendingPacket{ + routeInfo: routeInfo, + gso: gso, + proto: proto, + pkt: pkt, }) + if len(packets) > maxPendingPacketsPerResolution { + f.incrementOutgoingPacketErrors(packets[0].proto, packets[0].pkt) + packets[0] = pendingPacket{} + packets = packets[1:] + + if numPackets := len(packets); numPackets != maxPendingPacketsPerResolution { + panic(fmt.Sprintf("holding more queued packets than expected; got = %d, want <= %d", numPackets, maxPendingPacketsPerResolution)) + } + } + + f.mu.packets[ch] = packets + if ok { - return + return pkt.len(), nil } - // Wait for the link-address resolution to complete. - cancel := f.newCancelChannelLocked() - go func() { - cancelled := false - select { - case <-ch: - case <-cancel: - cancelled = true - } + cancelledPackets := f.newCancelChannelLocked(ch) - f.Lock() - packets, ok := f.packets[ch] - delete(f.packets, ch) - f.Unlock() + if len(cancelledPackets) != 0 { + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as handing link resolution failures may be a costly operation. + go f.dequeuePackets(cancelledPackets, "" /* linkAddr */, false /* success */) + } - if !ok { - panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets")) - } + return pkt.len(), nil +} - for _, p := range packets { - if cancelled || p.route.IsResolutionRequired() { - incrementOutgoingPacketErrors(r, proto, p.pkt) - - if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok { - switch pkt := p.pkt.(type) { - case *PacketBuffer: - linkResolvableEP.HandleLinkResolutionFailure(pkt) - case *PacketBufferList: - for pb := pkt.Front(); pb != nil; pb = pb.Next() { - linkResolvableEP.HandleLinkResolutionFailure(pb) - } - default: - panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt)) - } - } - } else { +// newCancelChannelLocked appends the link resolution channel to a FIFO. If the +// maximum number of pending resolutions is reached, the oldest channel will be +// removed and its associated pending packets will be returned. +func (f *packetsPendingLinkResolution) newCancelChannelLocked(newCH <-chan struct{}) []pendingPacket { + f.mu.cancelChans = append(f.mu.cancelChans, newCH) + if len(f.mu.cancelChans) <= maxPendingResolutions { + return nil + } + + ch := f.mu.cancelChans[0] + f.mu.cancelChans[0] = nil + f.mu.cancelChans = f.mu.cancelChans[1:] + if l := len(f.mu.cancelChans); l > maxPendingResolutions { + panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions)) + } + + packets, ok := f.mu.packets[ch] + if !ok { + panic("must have a packet queue for an uncancelled channel") + } + delete(f.mu.packets, ch) + + return packets +} + +func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, linkAddr tcpip.LinkAddress, success bool) { + for _, p := range packets { + if success { + p.routeInfo.RemoteLinkAddress = linkAddr + _, _ = f.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt) + } else { + f.incrementOutgoingPacketErrors(p.proto, p.pkt) + + if linkResolvableEP, ok := f.nic.getNetworkEndpoint(p.proto).(LinkResolvableNetworkEndpoint); ok { switch pkt := p.pkt.(type) { case *PacketBuffer: - p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, pkt) + linkResolvableEP.HandleLinkResolutionFailure(pkt) case *PacketBufferList: - p.route.outgoingNIC.writePackets(p.route.Fields(), nil /* gso */, p.proto, *pkt) + for pb := pkt.Front(); pb != nil; pb = pb.Next() { + linkResolvableEP.HandleLinkResolutionFailure(pb) + } default: panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt)) } } - p.route.Release() } - }() -} - -// newCancelChannel creates a channel that can cancel a pending forwarding -// activity. The oldest channel is closed if the number of open channels would -// exceed maxPendingResolutions. -func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} { - if len(f.cancelChans) == maxPendingResolutions { - ch := f.cancelChans[0] - f.cancelChans[0] = nil - f.cancelChans = f.cancelChans[1:] - close(ch) } - if l := len(f.cancelChans); l >= maxPendingResolutions { - panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions)) - } - - ch := make(chan struct{}) - f.cancelChans = append(f.cancelChans, ch) - return ch } diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 1ff7b3a37..093b676aa 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -86,12 +86,21 @@ type RouteInfo struct { RemoteLinkAddress tcpip.LinkAddress } -// Fields returns a RouteInfo with all of r's exported fields. This allows -// callers to store the route's fields without retaining a reference to it. +// Fields returns a RouteInfo with all of the known values for the route's +// fields. +// +// If any fields are unknown (e.g. remote link address when it is waiting for +// link address resolution), they will be unset. func (r *Route) Fields() RouteInfo { + r.mu.RLock() + defer r.mu.RUnlock() + return r.fieldsLocked() +} + +func (r *Route) fieldsLocked() RouteInfo { return RouteInfo{ routeInfo: r.routeInfo, - RemoteLinkAddress: r.RemoteLinkAddress(), + RemoteLinkAddress: r.mu.remoteLinkAddress, } } @@ -306,29 +315,26 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) { r.mu.remoteLinkAddress = addr } -// Resolve attempts to resolve the link address if necessary. +// ResolvedFields is like Fields but also attempts to resolve the remote link +// address if it is not yet known. // -// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g. -// waiting for ARP reply). If address resolution is required, a notification -// channel is also returned for the caller to block on. The channel is closed +// If address resolution is required, returns tcpip.ErrWouldBlock and a +// notification channel for the caller to block on. The channel will be readable // once address resolution is complete (successful or not). If a callback is // provided, it will be called when address resolution is complete, regardless -// of success or failure. -func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { - r.mu.Lock() - - if !r.isResolutionRequiredRLocked() { - // Nothing to do if there is no cache (which does the resolution on cache miss) or - // link address is already known. - r.mu.Unlock() - return nil, nil +// of success or failure before the notification channel is readable. +// +// Note, the route will not cache the remote link address when address +// resolution completes. +func (r *Route) ResolvedFields(afterResolve func()) (RouteInfo, <-chan struct{}, *tcpip.Error) { + r.mu.RLock() + fields := r.fieldsLocked() + resolutionRequired := r.isResolutionRequiredRLocked() + r.mu.RUnlock() + if !resolutionRequired { + return fields, nil, nil } - // Increment the route's reference count because finishResolution retains a - // reference to the route and releases it when called. - r.acquireLocked() - r.mu.Unlock() - nextAddr := r.NextHop if nextAddr == "" { nextAddr = r.RemoteAddress @@ -341,18 +347,15 @@ func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { linkAddressResolutionRequestLocalAddr = r.LocalAddress } - finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) { - if ok { - r.ResolveWith(linkAddress) - } + linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, func(tcpip.LinkAddress, bool) { if afterResolve != nil { afterResolve() } - r.Release() + }) + if err == nil { + fields.RemoteLinkAddress = linkAddr } - - _, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution) - return ch, err + return fields, ch, err } // local returns true if the route is a local route. @@ -371,11 +374,7 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isResolutionRequiredRLocked() bool { - if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() { - return false - } - - return r.linkRes != nil + return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local() } func (r *Route) isValidForOutgoing() bool { |