summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/transport_demuxer.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/transport_demuxer.go')
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go47
1 files changed, 26 insertions, 21 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 8c768c299..92267ce4d 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -63,7 +63,7 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, v
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
+ if isMulticastOrBroadcast(id.LocalAddress) {
mpep.handlePacketAll(r, id, vv)
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
return
@@ -338,23 +338,14 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
return false
}
- // If a sender bound to the Loopback interface sends a broadcast,
- // that broadcast must not be delivered to the sender.
- if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort {
- return false
- }
-
- // If the packet is a broadcast, then find all matching transport endpoints.
- // Otherwise, try to find a single matching transport endpoint.
- destEps := make([]*endpointsByNic, 0, 1)
eps.mu.RLock()
- if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast {
- for epID, endpoint := range eps.endpoints {
- if epID.LocalPort == id.LocalPort {
- destEps = append(destEps, endpoint)
- }
- }
+ // Determine which transport endpoint or endpoints to deliver this packet to.
+ // If the packet is a broadcast or multicast, then find all matching
+ // transport endpoints.
+ var destEps []*endpointsByNic
+ if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ destEps = d.findAllEndpointsLocked(eps, vv, id)
} else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
destEps = append(destEps, ep)
}
@@ -426,10 +417,11 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
return true
}
-func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
+func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic {
+ var matchedEPs []*endpointsByNic
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
// Try to find a match with the id minus the local address.
@@ -437,7 +429,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
// Try to find a match with the id minus the remote part.
@@ -445,15 +437,24 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
nid.RemoteAddress = ""
nid.RemotePort = 0
if ep, ok := eps.endpoints[nid]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
// Try to find a match with only the local port.
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
+ return matchedEPs
+}
+
+// findEndpointLocked returns the endpoint that most closely matches the given
+// id.
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
+ if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 {
+ return matchedEPs[0]
+ }
return nil
}
@@ -491,3 +492,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
}
}
}
+
+func isMulticastOrBroadcast(addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
+}