summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJamie Liu <jamieliu@google.com>2020-03-13 12:20:09 -0700
committergVisor bot <gvisor-bot@google.com>2020-03-13 12:22:19 -0700
commit86409c91813256f45ebcb08efeac9d7f9e56a804 (patch)
tree7507be8b78b3bc47fc70846e61725b729dbbd73c
parentb78cee3bae142eb5c602d51874d0cbad274777e2 (diff)
Avoid unnecessary work in transportDemuxer.deliverPacket().
- Don't allocate []*endpointsByNic in transportDemuxer.deliverPacket() unless actually needed for UDP broadcast/multicast. - Don't allocate []*endpointsByNic via transportDemuxer.findEndpointLocked() => transportDemuxer.findAllEndpointsLocked(). - Skip unnecessary map lookups in transportDemuxer.findEndpointLocked() => transportDemuxer.findAllEndpointsLocked() (now iterEndpointsLocked). For most deliverable packets other than UDP broadcast/multicast packets, this saves two slice allocations and three map lookups per packet. PiperOrigin-RevId: 300804135
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go112
1 files changed, 59 insertions, 53 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index ff1845bfb..d4c0359e8 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -409,61 +409,45 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
return false
}
- eps.mu.RLock()
-
- // Determine which transport endpoint or endpoints to deliver this packet to.
// If the packet is a UDP broadcast or multicast, then find all matching
- // transport endpoints. If the packet is a TCP packet with a non-unicast
- // source or destination address, then do nothing further and instruct
- // the caller to do the same.
- var destEps []*endpointsByNic
- switch protocol {
- case header.UDPProtocolNumber:
- if isMulticastOrBroadcast(id.LocalAddress) {
- destEps = d.findAllEndpointsLocked(eps, id)
- break
- }
-
- if ep := d.findEndpointLocked(eps, id); ep != nil {
- destEps = append(destEps, ep)
+ // transport endpoints.
+ if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ eps.mu.RLock()
+ destEPs := d.findAllEndpointsLocked(eps, id)
+ eps.mu.RUnlock()
+ // Fail if we didn't find at least one matching transport endpoint.
+ if len(destEPs) == 0 {
+ r.Stats().UDP.UnknownPortErrors.Increment()
+ return false
}
-
- case header.TCPProtocolNumber:
- if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) {
- // TCP can only be used to communicate between a single
- // source and a single destination; the addresses must
- // be unicast.
- eps.mu.RUnlock()
- r.Stats().TCP.InvalidSegmentsReceived.Increment()
- return true
+ // handlePacket takes ownership of pkt, so each endpoint needs its own
+ // copy except for the final one.
+ for _, ep := range destEPs[:len(destEPs)-1] {
+ ep.handlePacket(r, id, pkt.Clone())
}
+ destEPs[len(destEPs)-1].handlePacket(r, id, pkt)
+ return true
+ }
- fallthrough
-
- default:
- if ep := d.findEndpointLocked(eps, id); ep != nil {
- destEps = append(destEps, ep)
- }
+ // If the packet is a TCP packet with a non-unicast source or destination
+ // address, then do nothing further and instruct the caller to do the same.
+ if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ // TCP can only be used to communicate between a single source and a
+ // single destination; the addresses must be unicast.
+ r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ return true
}
+ eps.mu.RLock()
+ ep := d.findEndpointLocked(eps, id)
eps.mu.RUnlock()
-
- // Fail if we didn't find at least one matching transport endpoint.
- if len(destEps) == 0 {
- // UDP packet could not be delivered to an unknown destination port.
+ if ep == nil {
if protocol == header.UDPProtocolNumber {
r.Stats().UDP.UnknownPortErrors.Increment()
}
return false
}
-
- // HandlePacket takes ownership of pkt, so each endpoint needs its own
- // copy except for the final one.
- for _, ep := range destEps[:len(destEps)-1] {
- ep.handlePacket(r, id, pkt.Clone())
- }
- destEps[len(destEps)-1].handlePacket(r, id, pkt)
-
+ ep.handlePacket(r, id, pkt)
return true
}
@@ -515,11 +499,17 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
return true
}
-func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic {
- var matchedEPs []*endpointsByNic
+// iterEndpointsLocked yields all endpointsByNic in eps that match id, in
+// descending order of match quality. If a call to yield returns false,
+// iterEndpointsLocked stops iteration and returns immediately.
+//
+// Preconditions: eps.mu must be locked.
+func (d *transportDemuxer) iterEndpointsLocked(eps *transportEndpoints, id TransportEndpointID, yield func(*endpointsByNic) bool) {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
// Try to find a match with the id minus the local address.
@@ -527,7 +517,9 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
// Try to find a match with the id minus the remote part.
@@ -535,14 +527,26 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr
nid.RemoteAddress = ""
nid.RemotePort = 0
if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
// Try to find a match with only the local port.
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
+}
+
+func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic {
+ var matchedEPs []*endpointsByNic
+ d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool {
+ matchedEPs = append(matchedEPs, ep)
+ return true
+ })
return matchedEPs
}
@@ -580,10 +584,12 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
// findEndpointLocked returns the endpoint that most closely matches the given
// id.
func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic {
- if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 {
- return matchedEPs[0]
- }
- return nil
+ var matchedEP *endpointsByNic
+ d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool {
+ matchedEP = ep
+ return false
+ })
+ return matchedEP
}
// registerRawEndpoint registers the given endpoint with the dispatcher such