summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/nic.go44
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go18
2 files changed, 23 insertions, 39 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 863ef6bee..1f1a1426b 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -665,33 +665,15 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
}
}
- // Check if address is a broadcast address for the endpoint's network.
- //
- // Only IPv4 has a notion of broadcast addresses.
if protocol == header.IPv4ProtocolNumber {
- if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil {
n.mu.RUnlock()
return ref
}
}
-
- // A usable reference was not found, create a temporary one if requested by
- // the caller or if the IPv4 address is found in the NIC's subnets and the NIC
- // is a loopback interface.
- createTempEP := spoofingOrPromiscuous
- if !createTempEP && n.isLoopback() && protocol == header.IPv4ProtocolNumber {
- for _, r := range n.mu.endpoints {
- addr := r.addrWithPrefix()
- subnet := addr.Subnet()
- if subnet.Contains(address) {
- createTempEP = true
- break
- }
- }
- }
n.mu.RUnlock()
- if !createTempEP {
+ if !spoofingOrPromiscuous {
return nil
}
@@ -704,20 +686,21 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
return ref
}
-// getRefForBroadcastLocked returns an endpoint where address is the IPv4
-// broadcast address for the endpoint's network.
+// getRefForBroadcastOrLoopbackRLocked returns an endpoint whose address is the
+// broadcast address for the endpoint's network or an address in the endpoint's
+// subnet if the NIC is a loopback interface. This matches linux behaviour.
//
-// n.mu MUST be read locked.
-func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint {
+// n.mu MUST be read or write locked.
+func (n *NIC) getIPv4RefForBroadcastOrLoopbackRLocked(address tcpip.Address) *referencedNetworkEndpoint {
for _, ref := range n.mu.endpoints {
- // Only IPv4 has a notion of broadcast addresses.
+ // Only IPv4 has a notion of broadcast addresses or considers the loopback
+ // interface bound to an address's whole subnet (on linux).
if ref.protocol != header.IPv4ProtocolNumber {
continue
}
- addr := ref.addrWithPrefix()
- subnet := addr.Subnet()
- if subnet.IsBroadcast(address) && ref.tryIncRef() {
+ subnet := ref.addrWithPrefix().Subnet()
+ if (subnet.IsBroadcast(address) || (n.isLoopback() && subnet.Contains(address))) && ref.isValidForOutgoingRLocked() && ref.tryIncRef() {
return ref
}
}
@@ -745,11 +728,8 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add
n.removeEndpointLocked(ref)
}
- // Check if address is a broadcast address for an endpoint's network.
- //
- // Only IPv4 has a notion of broadcast addresses.
if protocol == header.IPv4ProtocolNumber {
- if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil {
return ref
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index b902c6ca9..0774b5382 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isMulticastOrBroadcast(id.LocalAddress) {
+ if isInboundMulticastOrBroadcast(r) {
mpep.handlePacketAll(r, id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
@@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
@@ -546,7 +546,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a TCP packet with a non-unicast source or destination
// address, then do nothing further and instruct the caller to do the same.
- if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ if protocol == header.TCPProtocolNumber && (!isInboundUnicast(r) || !isOutboundUnicast(r)) {
// TCP can only be used to communicate between a single source and a
// single destination; the addresses must be unicast.
r.Stats().TCP.InvalidSegmentsReceived.Increment()
@@ -677,10 +677,14 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isMulticastOrBroadcast(addr tcpip.Address) bool {
- return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
+func isInboundMulticastOrBroadcast(r *Route) bool {
+ return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
}
-func isUnicast(addr tcpip.Address) bool {
- return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
+func isInboundUnicast(r *Route) bool {
+ return r.LocalAddress != header.IPv4Any && r.LocalAddress != header.IPv6Any && !isInboundMulticastOrBroadcast(r)
+}
+
+func isOutboundUnicast(r *Route) bool {
+ return r.RemoteAddress != header.IPv4Any && r.RemoteAddress != header.IPv6Any && !r.IsOutboundBroadcast() && !header.IsV4MulticastAddress(r.RemoteAddress) && !header.IsV6MulticastAddress(r.RemoteAddress)
}