diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/network/internal/ip/stats.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 21 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_types.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/iptables_test.go | 296 | ||||
-rw-r--r-- | pkg/tcpip/tests/utils/utils.go | 34 |
9 files changed, 397 insertions, 12 deletions
diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index 444515d40..0c2b62127 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -74,6 +74,10 @@ type MultiCounterIPStats struct { // layer. PacketsReceived tcpip.MultiCounterStat + // ValidPacketsReceived is the number of valid IP packets that reached the IP + // layer. + ValidPacketsReceived tcpip.MultiCounterStat + // DisabledPacketsReceived is the number of IP packets received from // the link layer when the IP layer is disabled. DisabledPacketsReceived tcpip.MultiCounterStat @@ -114,6 +118,10 @@ type MultiCounterIPStats struct { // Input chain. IPTablesInputDropped tcpip.MultiCounterStat + // IPTablesForwardDropped is the number of IP packets dropped in the + // Forward chain. + IPTablesForwardDropped tcpip.MultiCounterStat + // IPTablesOutputDropped is the number of IP packets dropped in the // Output chain. IPTablesOutputDropped tcpip.MultiCounterStat @@ -146,6 +154,7 @@ type MultiCounterIPStats struct { // Init sets internal counters to track a and b counters. func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived) + m.ValidPacketsReceived.Init(a.ValidPacketsReceived, b.ValidPacketsReceived) m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived) m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived) m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived) @@ -156,6 +165,7 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived) m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped) m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped) + m.IPTablesForwardDropped.Init(a.IPTablesForwardDropped, b.IPTablesForwardDropped) m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped) m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped) m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index aef83e834..049811cbb 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -668,13 +668,23 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { } } + stk := e.protocol.stack + // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(ep.nic.ID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + ep.handleValidatedPacket(h, pkt) return nil } - r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) switch err.(type) { case nil: case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable: @@ -688,6 +698,14 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { } defer r.Release() + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(r.NICID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. @@ -803,6 +821,7 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats + stats.ip.ValidPacketsReceived.Increment() srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index febbb3f38..f0e06f86b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -941,8 +941,18 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return &ip.ErrTTLExceeded{} } + stk := e.protocol.stack + // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(ep.nic.ID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + ep.handleValidatedPacket(h, pkt) return nil } @@ -952,7 +962,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return &ip.ErrParameterProblem{} } - r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) switch err.(type) { case nil: case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable: @@ -965,6 +975,14 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { } defer r.Release() + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(r.NICID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. @@ -1073,6 +1091,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats.ip + stats.ValidPacketsReceived.Increment() + srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index e2894c548..3670d5995 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -177,6 +177,7 @@ func DefaultTables() *IPTables { priorities: [NumHooks][]TableID{ Prerouting: {MangleID, NATID}, Input: {NATID, FilterID}, + Forward: {FilterID}, Output: {MangleID, NATID, FilterID}, Postrouting: {MangleID, NATID}, }, diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 4631ab93f..93592e7f5 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -280,9 +280,18 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) case Output: return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) - case Forward, Postrouting: - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. + case Forward: + if !matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) { + return false + } + + if !matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) { + return false + } + + return true + case Postrouting: + // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING. return true default: panic(fmt.Sprintf("unknown hook: %d", hook)) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 7b9c8cd4f..797778e08 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1571,6 +1571,10 @@ type IPStats struct { // PacketsReceived is the number of IP packets received from the link layer. PacketsReceived *StatCounter + // ValidPacketsReceived is the number of valid IP packets that reached the IP + // layer. + ValidPacketsReceived *StatCounter + // DisabledPacketsReceived is the number of IP packets received from the link // layer when the IP layer is disabled. DisabledPacketsReceived *StatCounter @@ -1610,6 +1614,10 @@ type IPStats struct { // chain. IPTablesInputDropped *StatCounter + // IPTablesForwardDropped is the number of IP packets dropped in the Forward + // chain. + IPTablesForwardDropped *StatCounter + // IPTablesOutputDropped is the number of IP packets dropped in the Output // chain. IPTablesOutputDropped *StatCounter diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index d4f7bb5ff..ab2dab60c 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -31,12 +31,14 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/udp", ], ) diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index c61d4e788..07ba2b837 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -19,12 +19,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) @@ -645,3 +647,297 @@ func TestIPTableWritePackets(t *testing.T) { }) } } + +const ttl = 64 + +var ( + ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") + ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") +) + +func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoReply(e, src, dst, ttl) +} + +func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoReply(e, src, dst, ttl) +} + +func forwardedICMPv4EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply))) +} + +func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply))) +} + +func TestForwardingHook(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Name = "nic1" + nic2Name = "nic2" + + otherNICName = "otherNIC" + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + local bool + srcAddr, dstAddr tcpip.Address + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + checker func(*testing.T, []byte) + }{ + { + name: "IPv4 remote", + netProto: ipv4.ProtocolNumber, + local: false, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, + rx: rxICMPv4EchoReply, + checker: func(t *testing.T, b []byte) { + forwardedICMPv4EchoReplyChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv4 local", + netProto: ipv4.ProtocolNumber, + local: true, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr.Address, + rx: rxICMPv4EchoReply, + }, + { + name: "IPv6 remote", + netProto: ipv6.ProtocolNumber, + local: false, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, + rx: rxICMPv6EchoReply, + checker: func(t *testing.T, b []byte) { + forwardedICMPv6EchoReplyChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv6 local", + netProto: ipv6.ProtocolNumber, + local: true, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr.Address, + rx: rxICMPv6EchoReply, + }, + } + + setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { + return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, ipv6) + ruleIdx := filter.BuiltinChains[stack.Forward] + filter.Rules[ruleIdx].Filter = f + filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} + if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) + } + } + } + + boolToInt := func(v bool) uint64 { + if v { + return 1 + } + return 0 + } + + subTests := []struct { + name string + setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) + expectForward bool + }{ + { + name: "Accept", + setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, + expectForward: true, + }, + + { + name: "Drop", + setupFilter: setupDropFilter(stack.IPHeaderFilter{}), + expectForward: false, + }, + { + name: "Drop with input NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}), + expectForward: false, + }, + { + name: "Drop with output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}), + expectForward: false, + }, + { + name: "Drop with input and output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), + expectForward: false, + }, + + { + name: "Drop with other input NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}), + expectForward: true, + }, + { + name: "Drop with other output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}), + expectForward: true, + }, + { + name: "Drop with other input and output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), + expectForward: true, + }, + { + name: "Drop with input and other output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), + expectForward: true, + }, + { + name: "Drop with other input and other output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), + expectForward: true, + }, + + { + name: "Drop with inverted input NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), + expectForward: true, + }, + { + name: "Drop with inverted output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), + expectForward: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + }) + + subTest.setupFilter(t, s, test.netProto) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) + } + + if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + } + if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + } + + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID2, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID2, + }, + }) + + test.rx(e1, test.srcAddr, test.dstAddr) + + expectTransmitPacket := subTest.expectForward && !test.local + + ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) + } + ep1Stats := ep1.Stats() + ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) + } + ip1Stats := ipEP1Stats.IPStats() + + if got := ip1Stats.PacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) + } + if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want { + t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want) + } + if got := ip1Stats.PacketsSent.Value(); got != 0 { + t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got) + } + + ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) + } + ep2Stats := ep2.Stats() + ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) + } + ip2Stats := ipEP2Stats.IPStats() + if got := ip2Stats.PacketsReceived.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) + } + if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want { + t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want) + } + if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want { + t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want) + } + + p, ok := e2.Read() + if ok != expectTransmitPacket { + t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, expectTransmitPacket) + } + if expectTransmitPacket { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index c8b9c9b5c..2e6ae55ea 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -316,13 +316,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. }) } -// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on -// the provided endpoint. -func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { +func rxICMPv4Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv4Type) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) + pkt.SetType(ty) pkt.SetCode(header.ICMPv4UnusedCode) pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, 0)) @@ -341,13 +339,23 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) })) } -// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on +// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on // the provided endpoint. -func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { +func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { + rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4Echo) +} + +// RxICMPv4EchoReply constructs and injects an ICMPv4 echo reply packet on +// the provided endpoint. +func RxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { + rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4EchoReply) +} + +func rxICMPv6Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv6Type) { totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetType(ty) pkt.SetCode(header.ICMPv6UnusedCode) pkt.SetChecksum(0) pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ @@ -368,3 +376,15 @@ func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) Data: hdr.View().ToVectorisedView(), })) } + +// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on +// the provided endpoint. +func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { + rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoRequest) +} + +// RxICMPv6EchoReply constructs and injects an ICMPv6 echo reply packet on +// the provided endpoint. +func RxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { + rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoReply) +} |