diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/route.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 18 |
3 files changed, 23 insertions, 22 deletions
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index c2eabde9e..2cbbf0de8 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -48,10 +48,6 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping - - // directedBroadcast indicates whether this route is sending a directed - // broadcast packet. - directedBroadcast bool } // makeRoute initializes a new route. It takes ownership of the provided @@ -303,24 +299,27 @@ func (r *Route) Stack() *Stack { return r.ref.stack() } +func (r *Route) isV4Broadcast(addr tcpip.Address) bool { + if addr == header.IPv4Broadcast { + return true + } + + subnet := r.ref.addrWithPrefix().Subnet() + return subnet.IsBroadcast(addr) +} + // IsOutboundBroadcast returns true if the route is for an outbound broadcast // packet. func (r *Route) IsOutboundBroadcast() bool { // Only IPv4 has a notion of broadcast. - return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast + return r.isV4Broadcast(r.RemoteAddress) } // IsInboundBroadcast returns true if the route is for an inbound broadcast // packet. func (r *Route) IsInboundBroadcast() bool { // Only IPv4 has a notion of broadcast. - if r.LocalAddress == header.IPv4Broadcast { - return true - } - - addr := r.ref.addrWithPrefix() - subnet := addr.Subnet() - return subnet.IsBroadcast(r.LocalAddress) + return r.isV4Broadcast(r.LocalAddress) } // ReverseRoute returns new route with given source and destination address. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index def8b0b43..6a683545d 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1311,13 +1311,11 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) - r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr) - if len(route.Gateway) > 0 { if needRoute { r.NextHop = route.Gateway } - } else if r.directedBroadcast { + } else if subnet := ref.addrWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { r.RemoteLinkAddress = header.EthernetBroadcastAddress } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index b902c6ca9..0774b5382 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isMulticastOrBroadcast(id.LocalAddress) { + if isInboundMulticastOrBroadcast(r) { mpep.handlePacketAll(r, id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return @@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() @@ -546,7 +546,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a TCP packet with a non-unicast source or destination // address, then do nothing further and instruct the caller to do the same. - if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) { + if protocol == header.TCPProtocolNumber && (!isInboundUnicast(r) || !isOutboundUnicast(r)) { // TCP can only be used to communicate between a single source and a // single destination; the addresses must be unicast. r.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -677,10 +677,14 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isMulticastOrBroadcast(addr tcpip.Address) bool { - return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +func isInboundMulticastOrBroadcast(r *Route) bool { + return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) } -func isUnicast(addr tcpip.Address) bool { - return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) +func isInboundUnicast(r *Route) bool { + return r.LocalAddress != header.IPv4Any && r.LocalAddress != header.IPv6Any && !isInboundMulticastOrBroadcast(r) +} + +func isOutboundUnicast(r *Route) bool { + return r.RemoteAddress != header.IPv4Any && r.RemoteAddress != header.IPv6Any && !r.IsOutboundBroadcast() && !header.IsV4MulticastAddress(r.RemoteAddress) && !header.IsV6MulticastAddress(r.RemoteAddress) } |