summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/network/arp/arp.go4
-rw-r--r--pkg/tcpip/network/ip_test.go8
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go14
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go20
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go8
-rw-r--r--pkg/tcpip/stack/nic.go21
-rw-r--r--pkg/tcpip/stack/registration.go11
-rw-r--r--pkg/tcpip/stack/stack.go12
8 files changed, 83 insertions, 15 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index bd9b9c020..5d7803537 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -145,7 +145,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
stats.requestsReceived.Increment()
localAddr := tcpip.Address(h.ProtocolAddressTarget())
- if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ if !e.nic.CheckLocalAddress(header.IPv4ProtocolNumber, localAddr) {
stats.requestsReceivedUnknownTargetAddress.Increment()
return // we have no useful answer, ignore the request
}
@@ -281,7 +281,7 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
localAddr = addr.Address
- } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ } else if !e.nic.CheckLocalAddress(header.IPv4ProtocolNumber, localAddr) {
stats.outgoingRequestBadLocalAddressErrors.Increment()
return &tcpip.ErrBadLocalAddress{}
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index a176ef2b9..90236ed9e 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -314,6 +314,10 @@ func (*testInterface) Promiscuous() bool {
return false
}
+func (*testInterface) Spoofing() bool {
+ return false
+}
+
func (t *testInterface) setEnabled(v bool) {
t.mu.Lock()
defer t.mu.Unlock()
@@ -332,6 +336,10 @@ func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tc
return nil
}
+func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
+ return false
+}
+
func TestSourceAddressValidation(t *testing.T) {
rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 74e70e283..2b7bc0dd0 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -120,6 +120,18 @@ func (*icmpv4FragmentationNeededSockError) Kind() stack.TransportErrorKind {
return stack.PacketTooBigTransportError
}
+func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool {
+ if e.nic.Spoofing() {
+ return true
+ }
+
+ if addressEndpoint := e.AcquireAssignedAddress(addr, false, stack.NeverPrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ return true
+ }
+ return false
+}
+
// handleControl handles the case when an ICMP error packet contains the headers
// of the original packet that caused the ICMP one to be sent. This information
// is used to find out which transport endpoint must be notified about the ICMP
@@ -139,7 +151,7 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match an address we own.
srcAddr := hdr.SourceAddress()
- if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, srcAddr) == 0 {
+ if !e.checkLocalAddress(srcAddr) {
return
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index dcfd93bab..edf4ef4e5 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -148,6 +148,18 @@ func (*icmpv6PacketTooBigSockError) Kind() stack.TransportErrorKind {
return stack.PacketTooBigTransportError
}
+func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool {
+ if e.nic.Spoofing() {
+ return true
+ }
+
+ if addressEndpoint := e.AcquireAssignedAddress(addr, false, stack.NeverPrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ return true
+ }
+ return false
+}
+
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
@@ -165,8 +177,8 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
//
// Drop packet if it doesn't have the basic IPv6 header or if the
// original source address doesn't match an address we own.
- src := hdr.SourceAddress()
- if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
+ srcAddr := hdr.SourceAddress()
+ if !e.checkLocalAddress(srcAddr) {
return
}
@@ -192,7 +204,7 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
p = fragHdr.TransportProtocol()
}
- e.dispatcher.DeliverTransportError(src, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt)
+ e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt)
}
// getLinkAddrOption searches NDP options for a given link address option using
@@ -377,7 +389,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// section 5.4.3.
// Is the NS targeting us?
- if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
+ if !e.checkLocalAddress(targetAddr) {
return
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index ca46ec61f..4dbfb80da 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -124,6 +124,10 @@ func (*testInterface) Promiscuous() bool {
return false
}
+func (*testInterface) Spoofing() bool {
+ return false
+}
+
func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt)
}
@@ -149,6 +153,10 @@ func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber,
return nil
}
+func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
+ return false
+}
+
func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp header.ICMPv6) {
ip := buffer.NewView(header.IPv6MinimumSize)
header.IPv6(ip).Encode(&header.IPv6Fields{
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 6f2a0e487..a90d027f2 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -441,6 +441,13 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
+// Spoofing implements NetworkInterface.
+func (n *NIC) Spoofing() bool {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+ return n.mu.spoofing
+}
+
// primaryAddress returns an address that can be used to communicate with
// remoteAddr.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint {
@@ -994,3 +1001,17 @@ func (n *NIC) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, a
return &tcpip.ErrNotSupported{}
}
+
+// CheckLocalAddress implements NetworkInterface.
+func (n *NIC) CheckLocalAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ if n.Spoofing() {
+ return true
+ }
+
+ if addressEndpoint := n.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ return true
+ }
+
+ return false
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index d589f798d..2bc1c4270 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -514,8 +514,19 @@ type NetworkInterface interface {
Enabled() bool
// Promiscuous returns true if the interface is in promiscuous mode.
+ //
+ // When in promiscuous mode, the interface should accept all packets.
Promiscuous() bool
+ // Spoofing returns true if the interface is in spoofing mode.
+ //
+ // When in spoofing mode, the interface should consider all addresses as
+ // assigned to it.
+ Spoofing() bool
+
+ // CheckLocalAddress returns true if the address exists on the interface.
+ CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool
+
// WritePacketToRemote writes the packet to the given remote link address.
WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 035ab33ca..198e59c77 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1498,20 +1498,16 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto
return 0
}
- addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
- if addressEndpoint == nil {
- return 0
+ if nic.CheckLocalAddress(protocol, addr) {
+ return nic.id
}
- addressEndpoint.DecRef()
-
- return nic.id
+ return 0
}
// Go through all the NICs.
for _, nic := range s.nics {
- if addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint); addressEndpoint != nil {
- addressEndpoint.DecRef()
+ if nic.CheckLocalAddress(protocol, addr) {
return nic.id
}
}