diff options
Diffstat (limited to 'pkg/tcpip/stack/nic.go')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f21066fce..eaaf756cd 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -217,6 +217,11 @@ func (n *NIC) disableLocked() *tcpip.Error { } if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + // The NIC may have already left the multicast group. + if err := n.leaveGroupLocked(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + // The address may have already been removed. if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { return err @@ -255,6 +260,13 @@ func (n *NIC) enable() *tcpip.Error { if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { return err } + + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts + // multicast group. Note, the IANA calls the all-hosts multicast group the + // all-systems multicast group. + if err := n.joinGroupLocked(header.IPv4ProtocolNumber, header.IPv4AllSystems); err != nil { + return err + } } // Join the IPv6 All-Nodes Multicast group if the stack is configured to @@ -609,6 +621,9 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // If none exists a temporary one may be created if we are in promiscuous mode // or spoofing. Promiscuous mode will only be checked if promiscuous is true. // Similarly, spoofing will only be checked if spoofing is true. +// +// If the address is the IPv4 broadcast address for an endpoint's network, that +// endpoint will be returned. func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { n.mu.RLock() @@ -633,6 +648,16 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t } } + // Check if address is a broadcast address for the endpoint's network. + // + // Only IPv4 has a notion of broadcast addresses. + if protocol == header.IPv4ProtocolNumber { + if ref := n.getRefForBroadcastRLocked(address); ref != nil { + n.mu.RUnlock() + return ref + } + } + // A usable reference was not found, create a temporary one if requested by // the caller or if the address is found in the NIC's subnets. createTempEP := spoofingOrPromiscuous @@ -670,8 +695,34 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t return ref } +// getRefForBroadcastLocked returns an endpoint where address is the IPv4 +// broadcast address for the endpoint's network. +// +// n.mu MUST be read locked. +func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint { + for _, ref := range n.mu.endpoints { + // Only IPv4 has a notion of broadcast addresses. + if ref.protocol != header.IPv4ProtocolNumber { + continue + } + + addr := ref.addrWithPrefix() + subnet := addr.Subnet() + if subnet.IsBroadcast(address) && ref.tryIncRef() { + return ref + } + } + + return nil +} + /// getRefOrCreateTempLocked returns an existing endpoint for address or creates /// and returns a temporary endpoint. +// +// If the address is the IPv4 broadcast address for an endpoint's network, that +// endpoint will be returned. +// +// n.mu must be write locked. func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // No need to check the type as we are ok with expired endpoints at this @@ -685,6 +736,15 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add n.removeEndpointLocked(ref) } + // Check if address is a broadcast address for an endpoint's network. + // + // Only IPv4 has a notion of broadcast addresses. + if protocol == header.IPv4ProtocolNumber { + if ref := n.getRefForBroadcastRLocked(address); ref != nil { + return ref + } + } + // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { |