summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/stack/nic.go14
-rw-r--r--pkg/tcpip/tcpip.go10
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go52
3 files changed, 68 insertions, 8 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 53abf29e5..4afe7b744 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -984,7 +984,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
-// the NIC receives a packet from the physical interface.
+// the NIC receives a packet from the link endpoint.
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
@@ -1029,6 +1029,14 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
src, dst := netProto.ParseAddresses(pkt.Data.First())
+ if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
+ // The source address is one of our own, so we never should have gotten a
+ // packet like this unless handleLocal is false. Loopback also calls this
+ // function even though the packets didn't come from the physical interface
+ // so don't drop those.
+ n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
if ref := n.getRef(protocol, dst); ref != nil {
handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt)
return
@@ -1041,7 +1049,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
if n.stack.Forwarding() {
r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
if err != nil {
- n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
return
}
defer r.Release()
@@ -1079,7 +1087,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
// If a packet socket handled the packet, don't treat it as invalid.
if len(packetEPs) == 0 {
- n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
}
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index b7813cbc0..6243762e3 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -903,9 +903,13 @@ type IPStats struct {
// link layer in nic.DeliverNetworkPacket.
PacketsReceived *StatCounter
- // InvalidAddressesReceived is the total number of IP packets received
- // with an unknown or invalid destination address.
- InvalidAddressesReceived *StatCounter
+ // InvalidDestinationAddressesReceived is the total number of IP packets
+ // received with an unknown or invalid destination address.
+ InvalidDestinationAddressesReceived *StatCounter
+
+ // InvalidSourceAddressesReceived is the total number of IP packets received
+ // with a source address that should never have been received on the wire.
+ InvalidSourceAddressesReceived *StatCounter
// PacketsDelivered is the total number of incoming IP packets that
// are successfully delivered to the transport layer via HandlePacket.
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index ee9d10555..51bb61167 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -274,11 +274,16 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
-
- s := stack.New(stack.Options{
+ return newDualTestContextWithOptions(t, mtu, stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
})
+}
+
+func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext {
+ t.Helper()
+
+ s := stack.New(options)
ep := channel.New(256, mtu, "")
wep := stack.LinkEndpoint(ep)
@@ -763,6 +768,49 @@ func TestV6ReadOnV6(t *testing.T) {
testRead(c, unicastV6)
}
+// TestV4ReadSelfSource checks that packets coming from a local IP address are
+// correctly dropped when handleLocal is true and not otherwise.
+func TestV4ReadSelfSource(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ handleLocal bool
+ wantErr *tcpip.Error
+ wantInvalidSource uint64
+ }{
+ {"HandleLocal", false, nil, 0},
+ {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1},
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ HandleLocal: tt.handleLocal,
+ })
+ defer c.cleanup()
+
+ c.createEndpointForFlow(unicastV4)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ h.srcAddr = h.dstAddr
+
+ c.injectV4Packet(payload, &h, true /* valid */)
+
+ if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
+ t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
+ }
+
+ if _, _, err := c.ep.Read(nil); err != tt.wantErr {
+ t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
func TestV4ReadOnV4(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()