summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-05-13 18:52:06 -0700
committergVisor bot <gvisor-bot@google.com>2021-05-13 18:54:09 -0700
commit2b457d9ee9ba50da4a9208d957053fac2c77932d (patch)
tree8c711a691ef53a677083af1f884d664284f30f9c /pkg/tcpip
parent7ea2dcbaece00b5c7310c74fcf99c1fb32e9ec28 (diff)
Check filter table when forwarding IP packets
This change updates the forwarding path to perform the forwarding hook with iptables so that the filter table is consulted before a packet is forwarded Updates #170. Test: iptables_test.TestForwardingHook PiperOrigin-RevId: 373702359
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/network/internal/ip/stats.go10
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go21
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go22
-rw-r--r--pkg/tcpip/stack/iptables.go1
-rw-r--r--pkg/tcpip/stack/iptables_types.go15
-rw-r--r--pkg/tcpip/tcpip.go8
-rw-r--r--pkg/tcpip/tests/integration/BUILD2
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go296
-rw-r--r--pkg/tcpip/tests/utils/utils.go34
9 files changed, 397 insertions, 12 deletions
diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go
index 444515d40..0c2b62127 100644
--- a/pkg/tcpip/network/internal/ip/stats.go
+++ b/pkg/tcpip/network/internal/ip/stats.go
@@ -74,6 +74,10 @@ type MultiCounterIPStats struct {
// layer.
PacketsReceived tcpip.MultiCounterStat
+ // ValidPacketsReceived is the number of valid IP packets that reached the IP
+ // layer.
+ ValidPacketsReceived tcpip.MultiCounterStat
+
// DisabledPacketsReceived is the number of IP packets received from
// the link layer when the IP layer is disabled.
DisabledPacketsReceived tcpip.MultiCounterStat
@@ -114,6 +118,10 @@ type MultiCounterIPStats struct {
// Input chain.
IPTablesInputDropped tcpip.MultiCounterStat
+ // IPTablesForwardDropped is the number of IP packets dropped in the
+ // Forward chain.
+ IPTablesForwardDropped tcpip.MultiCounterStat
+
// IPTablesOutputDropped is the number of IP packets dropped in the
// Output chain.
IPTablesOutputDropped tcpip.MultiCounterStat
@@ -146,6 +154,7 @@ type MultiCounterIPStats struct {
// Init sets internal counters to track a and b counters.
func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived)
+ m.ValidPacketsReceived.Init(a.ValidPacketsReceived, b.ValidPacketsReceived)
m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived)
m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived)
m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived)
@@ -156,6 +165,7 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived)
m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped)
m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped)
+ m.IPTablesForwardDropped.Init(a.IPTablesForwardDropped, b.IPTablesForwardDropped)
m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped)
m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped)
m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived)
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index aef83e834..049811cbb 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -668,13 +668,23 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
}
}
+ stk := e.protocol.stack
+
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+ inNicName := stk.FindNICNameFromID(e.nic.ID())
+ outNicName := stk.FindNICNameFromID(ep.nic.ID())
+ if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesForwardDropped.Increment()
+ return nil
+ }
+
ep.handleValidatedPacket(h, pkt)
return nil
}
- r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
switch err.(type) {
case nil:
case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable:
@@ -688,6 +698,14 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
}
defer r.Release()
+ inNicName := stk.FindNICNameFromID(e.nic.ID())
+ outNicName := stk.FindNICNameFromID(r.NICID())
+ if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesForwardDropped.Increment()
+ return nil
+ }
+
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
@@ -803,6 +821,7 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.stats
+ stats.ip.ValidPacketsReceived.Increment()
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index febbb3f38..f0e06f86b 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -941,8 +941,18 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
return &ip.ErrTTLExceeded{}
}
+ stk := e.protocol.stack
+
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+ inNicName := stk.FindNICNameFromID(e.nic.ID())
+ outNicName := stk.FindNICNameFromID(ep.nic.ID())
+ if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesForwardDropped.Increment()
+ return nil
+ }
+
ep.handleValidatedPacket(h, pkt)
return nil
}
@@ -952,7 +962,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
return &ip.ErrParameterProblem{}
}
- r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
switch err.(type) {
case nil:
case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable:
@@ -965,6 +975,14 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
}
defer r.Release()
+ inNicName := stk.FindNICNameFromID(e.nic.ID())
+ outNicName := stk.FindNICNameFromID(r.NICID())
+ if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesForwardDropped.Increment()
+ return nil
+ }
+
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
@@ -1073,6 +1091,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.stats.ip
+ stats.ValidPacketsReceived.Increment()
+
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index e2894c548..3670d5995 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -177,6 +177,7 @@ func DefaultTables() *IPTables {
priorities: [NumHooks][]TableID{
Prerouting: {MangleID, NATID},
Input: {NATID, FilterID},
+ Forward: {FilterID},
Output: {MangleID, NATID, FilterID},
Postrouting: {MangleID, NATID},
},
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 4631ab93f..93592e7f5 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -280,9 +280,18 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa
return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert)
case Output:
return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert)
- case Forward, Postrouting:
- // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
- // hooks after supported.
+ case Forward:
+ if !matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) {
+ return false
+ }
+
+ if !matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) {
+ return false
+ }
+
+ return true
+ case Postrouting:
+ // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING.
return true
default:
panic(fmt.Sprintf("unknown hook: %d", hook))
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 7b9c8cd4f..797778e08 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1571,6 +1571,10 @@ type IPStats struct {
// PacketsReceived is the number of IP packets received from the link layer.
PacketsReceived *StatCounter
+ // ValidPacketsReceived is the number of valid IP packets that reached the IP
+ // layer.
+ ValidPacketsReceived *StatCounter
+
// DisabledPacketsReceived is the number of IP packets received from the link
// layer when the IP layer is disabled.
DisabledPacketsReceived *StatCounter
@@ -1610,6 +1614,10 @@ type IPStats struct {
// chain.
IPTablesInputDropped *StatCounter
+ // IPTablesForwardDropped is the number of IP packets dropped in the Forward
+ // chain.
+ IPTablesForwardDropped *StatCounter
+
// IPTablesOutputDropped is the number of IP packets dropped in the Output
// chain.
IPTablesOutputDropped *StatCounter
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index d4f7bb5ff..ab2dab60c 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -31,12 +31,14 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/udp",
],
)
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index c61d4e788..07ba2b837 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -19,12 +19,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
@@ -645,3 +647,297 @@ func TestIPTableWritePackets(t *testing.T) {
})
}
}
+
+const ttl = 64
+
+var (
+ ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
+ ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
+)
+
+func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv4EchoReply(e, src, dst, ttl)
+}
+
+func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv6EchoReply(e, src, dst, ttl)
+}
+
+func forwardedICMPv4EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply)))
+}
+
+func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply)))
+}
+
+func TestForwardingHook(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ nic1Name = "nic1"
+ nic2Name = "nic2"
+
+ otherNICName = "otherNIC"
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ local bool
+ srcAddr, dstAddr tcpip.Address
+ rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4 remote",
+ netProto: ipv4.ProtocolNumber,
+ local: false,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv4EchoReply,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv4EchoReplyChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv4 local",
+ netProto: ipv4.ProtocolNumber,
+ local: true,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4Addr.Address,
+ rx: rxICMPv4EchoReply,
+ },
+ {
+ name: "IPv6 remote",
+ netProto: ipv6.ProtocolNumber,
+ local: false,
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv6EchoReply,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv6EchoReplyChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv6 local",
+ netProto: ipv6.ProtocolNumber,
+ local: true,
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr.Address,
+ rx: rxICMPv6EchoReply,
+ },
+ }
+
+ setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
+ return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, ipv6)
+ ruleIdx := filter.BuiltinChains[stack.Forward]
+ filter.Rules[ruleIdx].Filter = f
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
+ }
+ }
+ }
+
+ boolToInt := func(v bool) uint64 {
+ if v {
+ return 1
+ }
+ return 0
+ }
+
+ subTests := []struct {
+ name string
+ setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
+ expectForward bool
+ }{
+ {
+ name: "Accept",
+ setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
+ expectForward: true,
+ },
+
+ {
+ name: "Drop",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{}),
+ expectForward: false,
+ },
+ {
+ name: "Drop with input NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}),
+ expectForward: false,
+ },
+ {
+ name: "Drop with output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}),
+ expectForward: false,
+ },
+ {
+ name: "Drop with input and output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
+ expectForward: false,
+ },
+
+ {
+ name: "Drop with other input NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with other output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with other input and output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with input and other output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with other input and other output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
+ expectForward: true,
+ },
+
+ {
+ name: "Drop with inverted input NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with inverted output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
+ expectForward: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+
+ subTest.setupFilter(t, s, test.netProto)
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
+ }
+
+ if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ }
+ if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ }
+
+ if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID2,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID2,
+ },
+ })
+
+ test.rx(e1, test.srcAddr, test.dstAddr)
+
+ expectTransmitPacket := subTest.expectForward && !test.local
+
+ ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
+ }
+ ep1Stats := ep1.Stats()
+ ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
+ }
+ ip1Stats := ipEP1Stats.IPStats()
+
+ if got := ip1Stats.PacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want {
+ t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want)
+ }
+ if got := ip1Stats.PacketsSent.Value(); got != 0 {
+ t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got)
+ }
+
+ ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
+ }
+ ep2Stats := ep2.Stats()
+ ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
+ }
+ ip2Stats := ipEP2Stats.IPStats()
+ if got := ip2Stats.PacketsReceived.Value(); got != 0 {
+ t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
+ }
+ if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want {
+ t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want)
+ }
+ if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want {
+ t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want)
+ }
+
+ p, ok := e2.Read()
+ if ok != expectTransmitPacket {
+ t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, expectTransmitPacket)
+ }
+ if expectTransmitPacket {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index c8b9c9b5c..2e6ae55ea 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -316,13 +316,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
})
}
-// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
-// the provided endpoint.
-func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+func rxICMPv4Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv4Type) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
hdr := buffer.NewPrependable(totalLen)
pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4Echo)
+ pkt.SetType(ty)
pkt.SetCode(header.ICMPv4UnusedCode)
pkt.SetChecksum(0)
pkt.SetChecksum(^header.Checksum(pkt, 0))
@@ -341,13 +339,23 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8)
}))
}
-// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
+// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
// the provided endpoint.
-func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+ rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4Echo)
+}
+
+// RxICMPv4EchoReply constructs and injects an ICMPv4 echo reply packet on
+// the provided endpoint.
+func RxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+ rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4EchoReply)
+}
+
+func rxICMPv6Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv6Type) {
totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
hdr := buffer.NewPrependable(totalLen)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
- pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetType(ty)
pkt.SetCode(header.ICMPv6UnusedCode)
pkt.SetChecksum(0)
pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
@@ -368,3 +376,15 @@ func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8)
Data: hdr.View().ToVectorisedView(),
}))
}
+
+// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
+// the provided endpoint.
+func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+ rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoRequest)
+}
+
+// RxICMPv6EchoReply constructs and injects an ICMPv6 echo reply packet on
+// the provided endpoint.
+func RxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+ rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoReply)
+}