diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 44 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 18 |
2 files changed, 23 insertions, 39 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 863ef6bee..1f1a1426b 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -665,33 +665,15 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t } } - // Check if address is a broadcast address for the endpoint's network. - // - // Only IPv4 has a notion of broadcast addresses. if protocol == header.IPv4ProtocolNumber { - if ref := n.getRefForBroadcastRLocked(address); ref != nil { + if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil { n.mu.RUnlock() return ref } } - - // A usable reference was not found, create a temporary one if requested by - // the caller or if the IPv4 address is found in the NIC's subnets and the NIC - // is a loopback interface. - createTempEP := spoofingOrPromiscuous - if !createTempEP && n.isLoopback() && protocol == header.IPv4ProtocolNumber { - for _, r := range n.mu.endpoints { - addr := r.addrWithPrefix() - subnet := addr.Subnet() - if subnet.Contains(address) { - createTempEP = true - break - } - } - } n.mu.RUnlock() - if !createTempEP { + if !spoofingOrPromiscuous { return nil } @@ -704,20 +686,21 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t return ref } -// getRefForBroadcastLocked returns an endpoint where address is the IPv4 -// broadcast address for the endpoint's network. +// getRefForBroadcastOrLoopbackRLocked returns an endpoint whose address is the +// broadcast address for the endpoint's network or an address in the endpoint's +// subnet if the NIC is a loopback interface. This matches linux behaviour. // -// n.mu MUST be read locked. -func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint { +// n.mu MUST be read or write locked. +func (n *NIC) getIPv4RefForBroadcastOrLoopbackRLocked(address tcpip.Address) *referencedNetworkEndpoint { for _, ref := range n.mu.endpoints { - // Only IPv4 has a notion of broadcast addresses. + // Only IPv4 has a notion of broadcast addresses or considers the loopback + // interface bound to an address's whole subnet (on linux). if ref.protocol != header.IPv4ProtocolNumber { continue } - addr := ref.addrWithPrefix() - subnet := addr.Subnet() - if subnet.IsBroadcast(address) && ref.tryIncRef() { + subnet := ref.addrWithPrefix().Subnet() + if (subnet.IsBroadcast(address) || (n.isLoopback() && subnet.Contains(address))) && ref.isValidForOutgoingRLocked() && ref.tryIncRef() { return ref } } @@ -745,11 +728,8 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add n.removeEndpointLocked(ref) } - // Check if address is a broadcast address for an endpoint's network. - // - // Only IPv4 has a notion of broadcast addresses. if protocol == header.IPv4ProtocolNumber { - if ref := n.getRefForBroadcastRLocked(address); ref != nil { + if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil { return ref } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index b902c6ca9..0774b5382 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isMulticastOrBroadcast(id.LocalAddress) { + if isInboundMulticastOrBroadcast(r) { mpep.handlePacketAll(r, id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return @@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() @@ -546,7 +546,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a TCP packet with a non-unicast source or destination // address, then do nothing further and instruct the caller to do the same. - if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) { + if protocol == header.TCPProtocolNumber && (!isInboundUnicast(r) || !isOutboundUnicast(r)) { // TCP can only be used to communicate between a single source and a // single destination; the addresses must be unicast. r.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -677,10 +677,14 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isMulticastOrBroadcast(addr tcpip.Address) bool { - return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +func isInboundMulticastOrBroadcast(r *Route) bool { + return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) } -func isUnicast(addr tcpip.Address) bool { - return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) +func isInboundUnicast(r *Route) bool { + return r.LocalAddress != header.IPv4Any && r.LocalAddress != header.IPv6Any && !isInboundMulticastOrBroadcast(r) +} + +func isOutboundUnicast(r *Route) bool { + return r.RemoteAddress != header.IPv4Any && r.RemoteAddress != header.IPv6Any && !r.IsOutboundBroadcast() && !header.IsV4MulticastAddress(r.RemoteAddress) && !header.IsV6MulticastAddress(r.RemoteAddress) } |