summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/nic.go27
-rw-r--r--pkg/tcpip/stack/route.go13
-rw-r--r--pkg/tcpip/stack/stack.go2
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go47
4 files changed, 49 insertions, 40 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index f6106f762..5993fe582 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -104,6 +104,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
func (n *NIC) enable() *tcpip.Error {
n.attachLinkEndpoint()
+ // Create an endpoint to receive broadcast packets on this interface.
+ if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
+ if err := n.AddAddress(tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize},
+ }, NeverPrimaryEndpoint); err != nil {
+ return err
+ }
+ }
+
// Join the IPv6 All-Nodes Multicast group if the stack is configured to
// use IPv6. This is required to ensure that this node properly receives
// and responds to the various NDP messages that are destined to the
@@ -372,7 +382,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
for nid, ref := range n.endpoints {
- // Don't include expired or tempory endpoints to avoid confusion and
+ // Don't include expired or temporary endpoints to avoid confusion and
// prevent the caller from using those.
switch ref.getKind() {
case permanentExpired, temporary:
@@ -624,21 +634,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n.stack.AddLinkAddress(n.id, src, remote)
- // If the packet is destined to the IPv4 Broadcast address, then make a
- // route to each IPv4 network endpoint and let each endpoint handle the
- // packet.
- if dst == header.IPv4Broadcast {
- // n.endpoints is mutex protected so acquire lock.
- n.mu.RLock()
- for _, ref := range n.endpoints {
- if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
- handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
- }
- }
- n.mu.RUnlock()
- return
- }
-
if ref := n.getRef(protocol, dst); ref != nil {
handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
return
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 5c8b7977a..0b09e6517 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -59,6 +59,8 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
loop = PacketLoop
} else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
loop |= PacketLoop
+ } else if remoteAddr == header.IPv4Broadcast {
+ loop |= PacketLoop
}
return Route{
@@ -208,10 +210,17 @@ func (r *Route) Clone() Route {
return *r
}
-// MakeLoopedRoute duplicates the given route and tweaks it in case of multicast.
+// MakeLoopedRoute duplicates the given route with special handling for routes
+// used for sending multicast or broadcast packets. In those cases the
+// multicast/broadcast address is the remote address when sending out, but for
+// incoming (looped) packets it becomes the local address. Similarly, the local
+// interface address that was the local address going out becomes the remote
+// address coming in. This is different to unicast routes where local and
+// remote addresses remain the same as they identify location (local vs remote)
+// not direction (source vs destination).
func (r *Route) MakeLoopedRoute() Route {
l := r.Clone()
- if header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
+ if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress
l.RemoteLinkAddress = l.LocalLinkAddress
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 90c2cf1be..ff574a055 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -902,7 +902,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
} else {
for _, route := range s.routeTable {
- if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) {
+ if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
continue
}
if nic, ok := s.nics[route.NIC]; ok {
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 8c768c299..92267ce4d 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -63,7 +63,7 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, v
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
+ if isMulticastOrBroadcast(id.LocalAddress) {
mpep.handlePacketAll(r, id, vv)
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
return
@@ -338,23 +338,14 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
return false
}
- // If a sender bound to the Loopback interface sends a broadcast,
- // that broadcast must not be delivered to the sender.
- if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort {
- return false
- }
-
- // If the packet is a broadcast, then find all matching transport endpoints.
- // Otherwise, try to find a single matching transport endpoint.
- destEps := make([]*endpointsByNic, 0, 1)
eps.mu.RLock()
- if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast {
- for epID, endpoint := range eps.endpoints {
- if epID.LocalPort == id.LocalPort {
- destEps = append(destEps, endpoint)
- }
- }
+ // Determine which transport endpoint or endpoints to deliver this packet to.
+ // If the packet is a broadcast or multicast, then find all matching
+ // transport endpoints.
+ var destEps []*endpointsByNic
+ if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ destEps = d.findAllEndpointsLocked(eps, vv, id)
} else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
destEps = append(destEps, ep)
}
@@ -426,10 +417,11 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
return true
}
-func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
+func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic {
+ var matchedEPs []*endpointsByNic
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
// Try to find a match with the id minus the local address.
@@ -437,7 +429,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
// Try to find a match with the id minus the remote part.
@@ -445,15 +437,24 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
nid.RemoteAddress = ""
nid.RemotePort = 0
if ep, ok := eps.endpoints[nid]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
// Try to find a match with only the local port.
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
+ return matchedEPs
+}
+
+// findEndpointLocked returns the endpoint that most closely matches the given
+// id.
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
+ if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 {
+ return matchedEPs[0]
+ }
return nil
}
@@ -491,3 +492,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
}
}
}
+
+func isMulticastOrBroadcast(addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
+}