summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go112
1 files changed, 59 insertions, 53 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index ff1845bfb..d4c0359e8 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -409,61 +409,45 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
return false
}
- eps.mu.RLock()
-
- // Determine which transport endpoint or endpoints to deliver this packet to.
// If the packet is a UDP broadcast or multicast, then find all matching
- // transport endpoints. If the packet is a TCP packet with a non-unicast
- // source or destination address, then do nothing further and instruct
- // the caller to do the same.
- var destEps []*endpointsByNic
- switch protocol {
- case header.UDPProtocolNumber:
- if isMulticastOrBroadcast(id.LocalAddress) {
- destEps = d.findAllEndpointsLocked(eps, id)
- break
- }
-
- if ep := d.findEndpointLocked(eps, id); ep != nil {
- destEps = append(destEps, ep)
+ // transport endpoints.
+ if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ eps.mu.RLock()
+ destEPs := d.findAllEndpointsLocked(eps, id)
+ eps.mu.RUnlock()
+ // Fail if we didn't find at least one matching transport endpoint.
+ if len(destEPs) == 0 {
+ r.Stats().UDP.UnknownPortErrors.Increment()
+ return false
}
-
- case header.TCPProtocolNumber:
- if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) {
- // TCP can only be used to communicate between a single
- // source and a single destination; the addresses must
- // be unicast.
- eps.mu.RUnlock()
- r.Stats().TCP.InvalidSegmentsReceived.Increment()
- return true
+ // handlePacket takes ownership of pkt, so each endpoint needs its own
+ // copy except for the final one.
+ for _, ep := range destEPs[:len(destEPs)-1] {
+ ep.handlePacket(r, id, pkt.Clone())
}
+ destEPs[len(destEPs)-1].handlePacket(r, id, pkt)
+ return true
+ }
- fallthrough
-
- default:
- if ep := d.findEndpointLocked(eps, id); ep != nil {
- destEps = append(destEps, ep)
- }
+ // If the packet is a TCP packet with a non-unicast source or destination
+ // address, then do nothing further and instruct the caller to do the same.
+ if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ // TCP can only be used to communicate between a single source and a
+ // single destination; the addresses must be unicast.
+ r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ return true
}
+ eps.mu.RLock()
+ ep := d.findEndpointLocked(eps, id)
eps.mu.RUnlock()
-
- // Fail if we didn't find at least one matching transport endpoint.
- if len(destEps) == 0 {
- // UDP packet could not be delivered to an unknown destination port.
+ if ep == nil {
if protocol == header.UDPProtocolNumber {
r.Stats().UDP.UnknownPortErrors.Increment()
}
return false
}
-
- // HandlePacket takes ownership of pkt, so each endpoint needs its own
- // copy except for the final one.
- for _, ep := range destEps[:len(destEps)-1] {
- ep.handlePacket(r, id, pkt.Clone())
- }
- destEps[len(destEps)-1].handlePacket(r, id, pkt)
-
+ ep.handlePacket(r, id, pkt)
return true
}
@@ -515,11 +499,17 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
return true
}
-func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic {
- var matchedEPs []*endpointsByNic
+// iterEndpointsLocked yields all endpointsByNic in eps that match id, in
+// descending order of match quality. If a call to yield returns false,
+// iterEndpointsLocked stops iteration and returns immediately.
+//
+// Preconditions: eps.mu must be locked.
+func (d *transportDemuxer) iterEndpointsLocked(eps *transportEndpoints, id TransportEndpointID, yield func(*endpointsByNic) bool) {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
// Try to find a match with the id minus the local address.
@@ -527,7 +517,9 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
// Try to find a match with the id minus the remote part.
@@ -535,14 +527,26 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr
nid.RemoteAddress = ""
nid.RemotePort = 0
if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
// Try to find a match with only the local port.
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
+}
+
+func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic {
+ var matchedEPs []*endpointsByNic
+ d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool {
+ matchedEPs = append(matchedEPs, ep)
+ return true
+ })
return matchedEPs
}
@@ -580,10 +584,12 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
// findEndpointLocked returns the endpoint that most closely matches the given
// id.
func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic {
- if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 {
- return matchedEPs[0]
- }
- return nil
+ var matchedEP *endpointsByNic
+ d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool {
+ matchedEP = ep
+ return false
+ })
+ return matchedEP
}
// registerRawEndpoint registers the given endpoint with the dispatcher such