diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-02-06 09:07:26 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-02-06 09:09:19 -0800 |
commit | c19e049f2c79ee9864cc273f6dc714b5caa434ca (patch) | |
tree | b36a569a4ce155548d75874b5237ada9792953f7 | |
parent | 83b764d9d2193e2e01f3a60792f3468c1843c5a8 (diff) |
Check local address directly through NIC
Network endpoints that wish to check addresses on another NIC-local
network endpoint may now do so through the NetworkInterface.
This fixes a lock ordering issue between NIC removal and link
resolution. Before this change:
NIC Removal takes the stack lock, neighbor cache lock then neighbor
entries' locks.
When performing IPv4 link resolution, we take the entry lock then ARP
would try check IPv4 local addresses through the stack which tries to
obtain the stack's lock.
Now that ARP can check IPv4 addreses through the NIC, we avoid the lock
ordering issue, while also removing the need for stack to lookup the
NIC.
PiperOrigin-RevId: 356034245
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 21 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 12 |
8 files changed, 83 insertions, 15 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index bd9b9c020..5d7803537 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -145,7 +145,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { stats.requestsReceived.Increment() localAddr := tcpip.Address(h.ProtocolAddressTarget()) - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + if !e.nic.CheckLocalAddress(header.IPv4ProtocolNumber, localAddr) { stats.requestsReceivedUnknownTargetAddress.Increment() return // we have no useful answer, ignore the request } @@ -281,7 +281,7 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } localAddr = addr.Address - } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + } else if !e.nic.CheckLocalAddress(header.IPv4ProtocolNumber, localAddr) { stats.outgoingRequestBadLocalAddressErrors.Increment() return &tcpip.ErrBadLocalAddress{} } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index a176ef2b9..90236ed9e 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -314,6 +314,10 @@ func (*testInterface) Promiscuous() bool { return false } +func (*testInterface) Spoofing() bool { + return false +} + func (t *testInterface) setEnabled(v bool) { t.mu.Lock() defer t.mu.Unlock() @@ -332,6 +336,10 @@ func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tc return nil } +func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { + return false +} + func TestSourceAddressValidation(t *testing.T) { rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 74e70e283..2b7bc0dd0 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -120,6 +120,18 @@ func (*icmpv4FragmentationNeededSockError) Kind() stack.TransportErrorKind { return stack.PacketTooBigTransportError } +func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool { + if e.nic.Spoofing() { + return true + } + + if addressEndpoint := e.AcquireAssignedAddress(addr, false, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return true + } + return false +} + // handleControl handles the case when an ICMP error packet contains the headers // of the original packet that caused the ICMP one to be sent. This information // is used to find out which transport endpoint must be notified about the ICMP @@ -139,7 +151,7 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match an address we own. srcAddr := hdr.SourceAddress() - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, srcAddr) == 0 { + if !e.checkLocalAddress(srcAddr) { return } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index dcfd93bab..edf4ef4e5 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -148,6 +148,18 @@ func (*icmpv6PacketTooBigSockError) Kind() stack.TransportErrorKind { return stack.PacketTooBigTransportError } +func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool { + if e.nic.Spoofing() { + return true + } + + if addressEndpoint := e.AcquireAssignedAddress(addr, false, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return true + } + return false +} + // handleControl handles the case when an ICMP packet contains the headers of // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP @@ -165,8 +177,8 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match an address we own. - src := hdr.SourceAddress() - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { + srcAddr := hdr.SourceAddress() + if !e.checkLocalAddress(srcAddr) { return } @@ -192,7 +204,7 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe p = fragHdr.TransportProtocol() } - e.dispatcher.DeliverTransportError(src, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt) + e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt) } // getLinkAddrOption searches NDP options for a given link address option using @@ -377,7 +389,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // section 5.4.3. // Is the NS targeting us? - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { + if !e.checkLocalAddress(targetAddr) { return } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index ca46ec61f..4dbfb80da 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -124,6 +124,10 @@ func (*testInterface) Promiscuous() bool { return false } +func (*testInterface) Spoofing() bool { + return false +} + func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) } @@ -149,6 +153,10 @@ func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, return nil } +func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { + return false +} + func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp header.ICMPv6) { ip := buffer.NewView(header.IPv6MinimumSize) header.IPv6(ip).Encode(&header.IPv6Fields{ diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 6f2a0e487..a90d027f2 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -441,6 +441,13 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } +// Spoofing implements NetworkInterface. +func (n *NIC) Spoofing() bool { + n.mu.RLock() + defer n.mu.RUnlock() + return n.mu.spoofing +} + // primaryAddress returns an address that can be used to communicate with // remoteAddr. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { @@ -994,3 +1001,17 @@ func (n *NIC) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, a return &tcpip.ErrNotSupported{} } + +// CheckLocalAddress implements NetworkInterface. +func (n *NIC) CheckLocalAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + if n.Spoofing() { + return true + } + + if addressEndpoint := n.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return true + } + + return false +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index d589f798d..2bc1c4270 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -514,8 +514,19 @@ type NetworkInterface interface { Enabled() bool // Promiscuous returns true if the interface is in promiscuous mode. + // + // When in promiscuous mode, the interface should accept all packets. Promiscuous() bool + // Spoofing returns true if the interface is in spoofing mode. + // + // When in spoofing mode, the interface should consider all addresses as + // assigned to it. + Spoofing() bool + + // CheckLocalAddress returns true if the address exists on the interface. + CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool + // WritePacketToRemote writes the packet to the given remote link address. WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 035ab33ca..198e59c77 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1498,20 +1498,16 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto return 0 } - addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) - if addressEndpoint == nil { - return 0 + if nic.CheckLocalAddress(protocol, addr) { + return nic.id } - addressEndpoint.DecRef() - - return nic.id + return 0 } // Go through all the NICs. for _, nic := range s.nics { - if addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint); addressEndpoint != nil { - addressEndpoint.DecRef() + if nic.CheckLocalAddress(protocol, addr) { return nic.id } } |