summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/tests
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-05-13 18:52:06 -0700
committergVisor bot <gvisor-bot@google.com>2021-05-13 18:54:09 -0700
commit2b457d9ee9ba50da4a9208d957053fac2c77932d (patch)
tree8c711a691ef53a677083af1f884d664284f30f9c /pkg/tcpip/tests
parent7ea2dcbaece00b5c7310c74fcf99c1fb32e9ec28 (diff)
Check filter table when forwarding IP packets
This change updates the forwarding path to perform the forwarding hook with iptables so that the filter table is consulted before a packet is forwarded Updates #170. Test: iptables_test.TestForwardingHook PiperOrigin-RevId: 373702359
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r--pkg/tcpip/tests/integration/BUILD2
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go296
-rw-r--r--pkg/tcpip/tests/utils/utils.go34
3 files changed, 325 insertions, 7 deletions
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index d4f7bb5ff..ab2dab60c 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -31,12 +31,14 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/udp",
],
)
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index c61d4e788..07ba2b837 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -19,12 +19,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
@@ -645,3 +647,297 @@ func TestIPTableWritePackets(t *testing.T) {
})
}
}
+
+const ttl = 64
+
+var (
+ ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
+ ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
+)
+
+func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv4EchoReply(e, src, dst, ttl)
+}
+
+func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv6EchoReply(e, src, dst, ttl)
+}
+
+func forwardedICMPv4EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply)))
+}
+
+func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply)))
+}
+
+func TestForwardingHook(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ nic1Name = "nic1"
+ nic2Name = "nic2"
+
+ otherNICName = "otherNIC"
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ local bool
+ srcAddr, dstAddr tcpip.Address
+ rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4 remote",
+ netProto: ipv4.ProtocolNumber,
+ local: false,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv4EchoReply,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv4EchoReplyChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv4 local",
+ netProto: ipv4.ProtocolNumber,
+ local: true,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4Addr.Address,
+ rx: rxICMPv4EchoReply,
+ },
+ {
+ name: "IPv6 remote",
+ netProto: ipv6.ProtocolNumber,
+ local: false,
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv6EchoReply,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv6EchoReplyChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv6 local",
+ netProto: ipv6.ProtocolNumber,
+ local: true,
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr.Address,
+ rx: rxICMPv6EchoReply,
+ },
+ }
+
+ setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
+ return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, ipv6)
+ ruleIdx := filter.BuiltinChains[stack.Forward]
+ filter.Rules[ruleIdx].Filter = f
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
+ }
+ }
+ }
+
+ boolToInt := func(v bool) uint64 {
+ if v {
+ return 1
+ }
+ return 0
+ }
+
+ subTests := []struct {
+ name string
+ setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
+ expectForward bool
+ }{
+ {
+ name: "Accept",
+ setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
+ expectForward: true,
+ },
+
+ {
+ name: "Drop",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{}),
+ expectForward: false,
+ },
+ {
+ name: "Drop with input NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}),
+ expectForward: false,
+ },
+ {
+ name: "Drop with output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}),
+ expectForward: false,
+ },
+ {
+ name: "Drop with input and output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
+ expectForward: false,
+ },
+
+ {
+ name: "Drop with other input NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with other output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with other input and output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with input and other output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with other input and other output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
+ expectForward: true,
+ },
+
+ {
+ name: "Drop with inverted input NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
+ expectForward: true,
+ },
+ {
+ name: "Drop with inverted output NIC filtering",
+ setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
+ expectForward: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+
+ subTest.setupFilter(t, s, test.netProto)
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
+ }
+
+ if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ }
+ if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ }
+
+ if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID2,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID2,
+ },
+ })
+
+ test.rx(e1, test.srcAddr, test.dstAddr)
+
+ expectTransmitPacket := subTest.expectForward && !test.local
+
+ ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
+ }
+ ep1Stats := ep1.Stats()
+ ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
+ }
+ ip1Stats := ipEP1Stats.IPStats()
+
+ if got := ip1Stats.PacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want {
+ t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want)
+ }
+ if got := ip1Stats.PacketsSent.Value(); got != 0 {
+ t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got)
+ }
+
+ ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
+ }
+ ep2Stats := ep2.Stats()
+ ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
+ }
+ ip2Stats := ipEP2Stats.IPStats()
+ if got := ip2Stats.PacketsReceived.Value(); got != 0 {
+ t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
+ }
+ if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want {
+ t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want)
+ }
+ if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want {
+ t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want)
+ }
+
+ p, ok := e2.Read()
+ if ok != expectTransmitPacket {
+ t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, expectTransmitPacket)
+ }
+ if expectTransmitPacket {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index c8b9c9b5c..2e6ae55ea 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -316,13 +316,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
})
}
-// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
-// the provided endpoint.
-func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+func rxICMPv4Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv4Type) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
hdr := buffer.NewPrependable(totalLen)
pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4Echo)
+ pkt.SetType(ty)
pkt.SetCode(header.ICMPv4UnusedCode)
pkt.SetChecksum(0)
pkt.SetChecksum(^header.Checksum(pkt, 0))
@@ -341,13 +339,23 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8)
}))
}
-// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
+// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
// the provided endpoint.
-func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+ rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4Echo)
+}
+
+// RxICMPv4EchoReply constructs and injects an ICMPv4 echo reply packet on
+// the provided endpoint.
+func RxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+ rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4EchoReply)
+}
+
+func rxICMPv6Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv6Type) {
totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
hdr := buffer.NewPrependable(totalLen)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
- pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetType(ty)
pkt.SetCode(header.ICMPv6UnusedCode)
pkt.SetChecksum(0)
pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
@@ -368,3 +376,15 @@ func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8)
Data: hdr.View().ToVectorisedView(),
}))
}
+
+// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
+// the provided endpoint.
+func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+ rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoRequest)
+}
+
+// RxICMPv6EchoReply constructs and injects an ICMPv6 echo reply packet on
+// the provided endpoint.
+func RxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
+ rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoReply)
+}