From 2b457d9ee9ba50da4a9208d957053fac2c77932d Mon Sep 17 00:00:00 2001
From: Ghanan Gowripalan <ghanan@google.com>
Date: Thu, 13 May 2021 18:52:06 -0700
Subject: Check filter table when forwarding IP packets

This change updates the forwarding path to perform the forwarding hook
with iptables so that the filter table is consulted before a packet is
forwarded

Updates #170.

Test: iptables_test.TestForwardingHook
PiperOrigin-RevId: 373702359
---
 pkg/tcpip/network/internal/ip/stats.go       |  10 +
 pkg/tcpip/network/ipv4/ipv4.go               |  21 +-
 pkg/tcpip/network/ipv6/ipv6.go               |  22 +-
 pkg/tcpip/stack/iptables.go                  |   1 +
 pkg/tcpip/stack/iptables_types.go            |  15 +-
 pkg/tcpip/tcpip.go                           |   8 +
 pkg/tcpip/tests/integration/BUILD            |   2 +
 pkg/tcpip/tests/integration/iptables_test.go | 296 +++++++++++++++++++++++++++
 pkg/tcpip/tests/utils/utils.go               |  34 ++-
 9 files changed, 397 insertions(+), 12 deletions(-)

(limited to 'pkg/tcpip')

diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go
index 444515d40..0c2b62127 100644
--- a/pkg/tcpip/network/internal/ip/stats.go
+++ b/pkg/tcpip/network/internal/ip/stats.go
@@ -74,6 +74,10 @@ type MultiCounterIPStats struct {
 	// layer.
 	PacketsReceived tcpip.MultiCounterStat
 
+	// ValidPacketsReceived is the number of valid IP packets that reached the IP
+	// layer.
+	ValidPacketsReceived tcpip.MultiCounterStat
+
 	// DisabledPacketsReceived is the number of IP packets received from
 	// the link layer when the IP layer is disabled.
 	DisabledPacketsReceived tcpip.MultiCounterStat
@@ -114,6 +118,10 @@ type MultiCounterIPStats struct {
 	// Input chain.
 	IPTablesInputDropped tcpip.MultiCounterStat
 
+	// IPTablesForwardDropped is the number of IP packets dropped in the
+	// Forward chain.
+	IPTablesForwardDropped tcpip.MultiCounterStat
+
 	// IPTablesOutputDropped is the number of IP packets dropped in the
 	// Output chain.
 	IPTablesOutputDropped tcpip.MultiCounterStat
@@ -146,6 +154,7 @@ type MultiCounterIPStats struct {
 // Init sets internal counters to track a and b counters.
 func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
 	m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived)
+	m.ValidPacketsReceived.Init(a.ValidPacketsReceived, b.ValidPacketsReceived)
 	m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived)
 	m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived)
 	m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived)
@@ -156,6 +165,7 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
 	m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived)
 	m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped)
 	m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped)
+	m.IPTablesForwardDropped.Init(a.IPTablesForwardDropped, b.IPTablesForwardDropped)
 	m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped)
 	m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped)
 	m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived)
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index aef83e834..049811cbb 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -668,13 +668,23 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
 		}
 	}
 
+	stk := e.protocol.stack
+
 	// Check if the destination is owned by the stack.
 	if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+		inNicName := stk.FindNICNameFromID(e.nic.ID())
+		outNicName := stk.FindNICNameFromID(ep.nic.ID())
+		if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+			// iptables is telling us to drop the packet.
+			e.stats.ip.IPTablesForwardDropped.Increment()
+			return nil
+		}
+
 		ep.handleValidatedPacket(h, pkt)
 		return nil
 	}
 
-	r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+	r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
 	switch err.(type) {
 	case nil:
 	case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable:
@@ -688,6 +698,14 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
 	}
 	defer r.Release()
 
+	inNicName := stk.FindNICNameFromID(e.nic.ID())
+	outNicName := stk.FindNICNameFromID(r.NICID())
+	if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+		// iptables is telling us to drop the packet.
+		e.stats.ip.IPTablesForwardDropped.Increment()
+		return nil
+	}
+
 	// We need to do a deep copy of the IP packet because
 	// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
 	// not own it.
@@ -803,6 +821,7 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
 func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) {
 	pkt.NICID = e.nic.ID()
 	stats := e.stats
+	stats.ip.ValidPacketsReceived.Increment()
 
 	srcAddr := h.SourceAddress()
 	dstAddr := h.DestinationAddress()
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index febbb3f38..f0e06f86b 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -941,8 +941,18 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
 		return &ip.ErrTTLExceeded{}
 	}
 
+	stk := e.protocol.stack
+
 	// Check if the destination is owned by the stack.
 	if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+		inNicName := stk.FindNICNameFromID(e.nic.ID())
+		outNicName := stk.FindNICNameFromID(ep.nic.ID())
+		if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+			// iptables is telling us to drop the packet.
+			e.stats.ip.IPTablesForwardDropped.Increment()
+			return nil
+		}
+
 		ep.handleValidatedPacket(h, pkt)
 		return nil
 	}
@@ -952,7 +962,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
 		return &ip.ErrParameterProblem{}
 	}
 
-	r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+	r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
 	switch err.(type) {
 	case nil:
 	case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable:
@@ -965,6 +975,14 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
 	}
 	defer r.Release()
 
+	inNicName := stk.FindNICNameFromID(e.nic.ID())
+	outNicName := stk.FindNICNameFromID(r.NICID())
+	if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+		// iptables is telling us to drop the packet.
+		e.stats.ip.IPTablesForwardDropped.Increment()
+		return nil
+	}
+
 	// We need to do a deep copy of the IP packet because
 	// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
 	// not own it.
@@ -1073,6 +1091,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
 func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) {
 	pkt.NICID = e.nic.ID()
 	stats := e.stats.ip
+	stats.ValidPacketsReceived.Increment()
+
 	srcAddr := h.SourceAddress()
 	dstAddr := h.DestinationAddress()
 
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index e2894c548..3670d5995 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -177,6 +177,7 @@ func DefaultTables() *IPTables {
 		priorities: [NumHooks][]TableID{
 			Prerouting:  {MangleID, NATID},
 			Input:       {NATID, FilterID},
+			Forward:     {FilterID},
 			Output:      {MangleID, NATID, FilterID},
 			Postrouting: {MangleID, NATID},
 		},
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 4631ab93f..93592e7f5 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -280,9 +280,18 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa
 		return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert)
 	case Output:
 		return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert)
-	case Forward, Postrouting:
-		// TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
-		// hooks after supported.
+	case Forward:
+		if !matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) {
+			return false
+		}
+
+		if !matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) {
+			return false
+		}
+
+		return true
+	case Postrouting:
+		// TODO(gvisor.dev/issue/170): Add the check for POSTROUTING.
 		return true
 	default:
 		panic(fmt.Sprintf("unknown hook: %d", hook))
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 7b9c8cd4f..797778e08 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1571,6 +1571,10 @@ type IPStats struct {
 	// PacketsReceived is the number of IP packets received from the link layer.
 	PacketsReceived *StatCounter
 
+	// ValidPacketsReceived is the number of valid IP packets that reached the IP
+	// layer.
+	ValidPacketsReceived *StatCounter
+
 	// DisabledPacketsReceived is the number of IP packets received from the link
 	// layer when the IP layer is disabled.
 	DisabledPacketsReceived *StatCounter
@@ -1610,6 +1614,10 @@ type IPStats struct {
 	// chain.
 	IPTablesInputDropped *StatCounter
 
+	// IPTablesForwardDropped is the number of IP packets dropped in the Forward
+	// chain.
+	IPTablesForwardDropped *StatCounter
+
 	// IPTablesOutputDropped is the number of IP packets dropped in the Output
 	// chain.
 	IPTablesOutputDropped *StatCounter
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index d4f7bb5ff..ab2dab60c 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -31,12 +31,14 @@ go_test(
     deps = [
         "//pkg/tcpip",
         "//pkg/tcpip/buffer",
+        "//pkg/tcpip/checker",
         "//pkg/tcpip/header",
         "//pkg/tcpip/link/channel",
         "//pkg/tcpip/network/ipv4",
         "//pkg/tcpip/network/ipv6",
         "//pkg/tcpip/stack",
         "//pkg/tcpip/tests/utils",
+        "//pkg/tcpip/testutil",
         "//pkg/tcpip/transport/udp",
     ],
 )
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index c61d4e788..07ba2b837 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -19,12 +19,14 @@ import (
 
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip/checker"
 	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
 	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
 	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
 	"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+	"gvisor.dev/gvisor/pkg/tcpip/testutil"
 	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
 )
 
@@ -645,3 +647,297 @@ func TestIPTableWritePackets(t *testing.T) {
 		})
 	}
 }
+
+const ttl = 64
+
+var (
+	ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
+	ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
+)
+
+func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
+	utils.RxICMPv4EchoReply(e, src, dst, ttl)
+}
+
+func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
+	utils.RxICMPv6EchoReply(e, src, dst, ttl)
+}
+
+func forwardedICMPv4EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+	checker.IPv4(t, b,
+		checker.SrcAddr(src),
+		checker.DstAddr(dst),
+		checker.TTL(ttl-1),
+		checker.ICMPv4(
+			checker.ICMPv4Type(header.ICMPv4EchoReply)))
+}
+
+func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+	checker.IPv6(t, b,
+		checker.SrcAddr(src),
+		checker.DstAddr(dst),
+		checker.TTL(ttl-1),
+		checker.ICMPv6(
+			checker.ICMPv6Type(header.ICMPv6EchoReply)))
+}
+
+func TestForwardingHook(t *testing.T) {
+	const (
+		nicID1 = 1
+		nicID2 = 2
+
+		nic1Name = "nic1"
+		nic2Name = "nic2"
+
+		otherNICName = "otherNIC"
+	)
+
+	tests := []struct {
+		name             string
+		netProto         tcpip.NetworkProtocolNumber
+		local            bool
+		srcAddr, dstAddr tcpip.Address
+		rx               func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+		checker          func(*testing.T, []byte)
+	}{
+		{
+			name:     "IPv4 remote",
+			netProto: ipv4.ProtocolNumber,
+			local:    false,
+			srcAddr:  utils.RemoteIPv4Addr,
+			dstAddr:  utils.Ipv4Addr2.AddressWithPrefix.Address,
+			rx:       rxICMPv4EchoReply,
+			checker: func(t *testing.T, b []byte) {
+				forwardedICMPv4EchoReplyChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+			},
+		},
+		{
+			name:     "IPv4 local",
+			netProto: ipv4.ProtocolNumber,
+			local:    true,
+			srcAddr:  utils.RemoteIPv4Addr,
+			dstAddr:  utils.Ipv4Addr.Address,
+			rx:       rxICMPv4EchoReply,
+		},
+		{
+			name:     "IPv6 remote",
+			netProto: ipv6.ProtocolNumber,
+			local:    false,
+			srcAddr:  utils.RemoteIPv6Addr,
+			dstAddr:  utils.Ipv6Addr2.AddressWithPrefix.Address,
+			rx:       rxICMPv6EchoReply,
+			checker: func(t *testing.T, b []byte) {
+				forwardedICMPv6EchoReplyChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+			},
+		},
+		{
+			name:     "IPv6 local",
+			netProto: ipv6.ProtocolNumber,
+			local:    true,
+			srcAddr:  utils.RemoteIPv6Addr,
+			dstAddr:  utils.Ipv6Addr.Address,
+			rx:       rxICMPv6EchoReply,
+		},
+	}
+
+	setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
+		return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
+			t.Helper()
+
+			ipv6 := netProto == ipv6.ProtocolNumber
+
+			ipt := s.IPTables()
+			filter := ipt.GetTable(stack.FilterID, ipv6)
+			ruleIdx := filter.BuiltinChains[stack.Forward]
+			filter.Rules[ruleIdx].Filter = f
+			filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
+			// Make sure the packet is not dropped by the next rule.
+			filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
+			if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
+				t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
+			}
+		}
+	}
+
+	boolToInt := func(v bool) uint64 {
+		if v {
+			return 1
+		}
+		return 0
+	}
+
+	subTests := []struct {
+		name          string
+		setupFilter   func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
+		expectForward bool
+	}{
+		{
+			name:          "Accept",
+			setupFilter:   func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
+			expectForward: true,
+		},
+
+		{
+			name:          "Drop",
+			setupFilter:   setupDropFilter(stack.IPHeaderFilter{}),
+			expectForward: false,
+		},
+		{
+			name:          "Drop with input NIC filtering",
+			setupFilter:   setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}),
+			expectForward: false,
+		},
+		{
+			name:          "Drop with output NIC filtering",
+			setupFilter:   setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}),
+			expectForward: false,
+		},
+		{
+			name:          "Drop with input and output NIC filtering",
+			setupFilter:   setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
+			expectForward: false,
+		},
+
+		{
+			name:          "Drop with other input NIC filtering",
+			setupFilter:   setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}),
+			expectForward: true,
+		},
+		{
+			name:          "Drop with other output NIC filtering",
+			setupFilter:   setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}),
+			expectForward: true,
+		},
+		{
+			name:          "Drop with other input and output NIC filtering",
+			setupFilter:   setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
+			expectForward: true,
+		},
+		{
+			name:          "Drop with input and other output NIC filtering",
+			setupFilter:   setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
+			expectForward: true,
+		},
+		{
+			name:          "Drop with other input and other output NIC filtering",
+			setupFilter:   setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
+			expectForward: true,
+		},
+
+		{
+			name:          "Drop with inverted input NIC filtering",
+			setupFilter:   setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
+			expectForward: true,
+		},
+		{
+			name:          "Drop with inverted output NIC filtering",
+			setupFilter:   setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
+			expectForward: true,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			for _, subTest := range subTests {
+				t.Run(subTest.name, func(t *testing.T) {
+					s := stack.New(stack.Options{
+						NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+					})
+
+					subTest.setupFilter(t, s, test.netProto)
+
+					e1 := channel.New(1, header.IPv6MinimumMTU, "")
+					if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
+						t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
+					}
+
+					e2 := channel.New(1, header.IPv6MinimumMTU, "")
+					if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
+						t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
+					}
+
+					if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
+						t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+					}
+					if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
+						t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+					}
+
+					if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
+						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
+					}
+					if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+						t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
+					}
+
+					s.SetRouteTable([]tcpip.Route{
+						{
+							Destination: header.IPv4EmptySubnet,
+							NIC:         nicID2,
+						},
+						{
+							Destination: header.IPv6EmptySubnet,
+							NIC:         nicID2,
+						},
+					})
+
+					test.rx(e1, test.srcAddr, test.dstAddr)
+
+					expectTransmitPacket := subTest.expectForward && !test.local
+
+					ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
+					if err != nil {
+						t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
+					}
+					ep1Stats := ep1.Stats()
+					ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
+					if !ok {
+						t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
+					}
+					ip1Stats := ipEP1Stats.IPStats()
+
+					if got := ip1Stats.PacketsReceived.Value(); got != 1 {
+						t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
+					}
+					if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
+						t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+					}
+					if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want {
+						t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want)
+					}
+					if got := ip1Stats.PacketsSent.Value(); got != 0 {
+						t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got)
+					}
+
+					ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
+					if err != nil {
+						t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
+					}
+					ep2Stats := ep2.Stats()
+					ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
+					if !ok {
+						t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
+					}
+					ip2Stats := ipEP2Stats.IPStats()
+					if got := ip2Stats.PacketsReceived.Value(); got != 0 {
+						t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
+					}
+					if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want {
+						t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want)
+					}
+					if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want {
+						t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want)
+					}
+
+					p, ok := e2.Read()
+					if ok != expectTransmitPacket {
+						t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, expectTransmitPacket)
+					}
+					if expectTransmitPacket {
+						test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+					}
+				})
+			}
+		})
+	}
+}
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index c8b9c9b5c..2e6ae55ea 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -316,13 +316,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
 	})
 }
 
-// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
-// the provided endpoint.
-func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+func rxICMPv4Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv4Type) {
 	totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
 	hdr := buffer.NewPrependable(totalLen)
 	pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
-	pkt.SetType(header.ICMPv4Echo)
+	pkt.SetType(ty)
 	pkt.SetCode(header.ICMPv4UnusedCode)
 	pkt.SetChecksum(0)
 	pkt.SetChecksum(^header.Checksum(pkt, 0))
@@ -341,13 +339,23 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8)
 	}))
 }
 
-// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
+// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
 // the provided endpoint.
-func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+	rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4Echo)
+}
+
+// RxICMPv4EchoReply constructs and injects an ICMPv4 echo reply packet on
+// the provided endpoint.
+func RxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+	rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4EchoReply)
+}
+
+func rxICMPv6Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv6Type) {
 	totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
 	hdr := buffer.NewPrependable(totalLen)
 	pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
-	pkt.SetType(header.ICMPv6EchoRequest)
+	pkt.SetType(ty)
 	pkt.SetCode(header.ICMPv6UnusedCode)
 	pkt.SetChecksum(0)
 	pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
@@ -368,3 +376,15 @@ func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8)
 		Data: hdr.View().ToVectorisedView(),
 	}))
 }
+
+// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
+// the provided endpoint.
+func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+	rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoRequest)
+}
+
+// RxICMPv6EchoReply constructs and injects an ICMPv6 echo reply packet on
+// the provided endpoint.
+func RxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+	rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoReply)
+}
-- 
cgit v1.2.3