diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/route.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 50 |
2 files changed, 57 insertions, 11 deletions
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 53cb6694f..13b42d34b 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -71,22 +71,24 @@ type Route struct { // ownership of the provided local address. // // Returns an empty route if validation fails. -func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { - addrWithPrefix := addressEndpoint.AddressWithPrefix() +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { + if len(localAddr) == 0 { + localAddr = addressEndpoint.AddressWithPrefix().Address + } - if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) { + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) { addressEndpoint.DecRef() return Route{} } // If no remote address is provided, use the local address. if len(remoteAddr) == 0 { - remoteAddr = addrWithPrefix.Address + remoteAddr = localAddr } r := makeRoute( netProto, - addrWithPrefix.Address, + localAddr, remoteAddr, outgoingNIC, localAddressNIC, @@ -99,7 +101,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp // broadcast it. if len(gateway) > 0 { r.NextHop = gateway - } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) { + } else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) { r.RemoteLinkAddress = header.EthernetBroadcastAddress } @@ -113,6 +115,10 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) } + if len(localAddr) == 0 { + localAddr = localAddressEndpoint.AddressWithPrefix().Address + } + loop := PacketOut // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index f4504e633..e788344d9 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1240,7 +1240,7 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re r := makeLocalRoute( netProto, - localAddressEndpoint.AddressWithPrefix().Address, + localAddr, remoteAddr, outgoingNIC, localAddressNIC, @@ -1317,7 +1317,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { return makeRoute( netProto, - addressEndpoint.AddressWithPrefix().Address, + localAddr, remoteAddr, nic, /* outboundNIC */ nic, /* localAddressNIC*/ @@ -1354,7 +1354,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if needRoute { gateway = route.Gateway } - r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop) + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) if r == (Route{}) { panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) } @@ -1391,7 +1391,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 { if aNIC, ok := s.nics[id]; ok { if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { return r, nil } } @@ -1409,7 +1409,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n continue } - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { return r, nil } } @@ -2130,3 +2130,43 @@ func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { } return protos } + +func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint) + if addressEndpoint == nil { + return false + } + + subnet := addressEndpoint.Subnet() + addressEndpoint.DecRef() + return subnet.IsBroadcast(addr) +} + +// IsSubnetBroadcast returns true if the provided address is a subnet-local +// broadcast address on the specified NIC and protocol. +// +// Returns false if the NIC is unknown or if the protocol is unknown or does +// not support addressing. +// +// If the NIC is not specified, the stack will check all NICs. +func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + s.mu.RLock() + defer s.mu.RUnlock() + + if nicID != 0 { + nic, ok := s.nics[nicID] + if !ok { + return false + } + + return isSubnetBroadcastOnNIC(nic, protocol, addr) + } + + for _, nic := range s.nics { + if isSubnetBroadcastOnNIC(nic, protocol, addr) { + return true + } + } + + return false +} |