From df2352796d1cbe5eea563d54380be60be18455bc Mon Sep 17 00:00:00 2001
From: Ghanan Gowripalan <ghanan@google.com>
Date: Fri, 14 May 2021 16:29:33 -0700
Subject: Control forwarding per NetworkEndpoint

...instead of per NetworkProtocol to better conform with linux
(https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt):

```
conf/interface/*

forwarding - BOOLEAN
	Enable IP forwarding on this interface.  This controls whether packets
	received _on_ this interface can be forwarded.
```

Fixes #5932.

PiperOrigin-RevId: 373888000
---
 pkg/tcpip/tests/integration/forward_test.go | 248 ++++++++++++++++++++++++----
 1 file changed, 215 insertions(+), 33 deletions(-)

(limited to 'pkg/tcpip/tests/integration')

diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 42bc53328..92fa6257d 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -16,6 +16,7 @@ package forward_test
 
 import (
 	"bytes"
+	"fmt"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -34,6 +35,39 @@ import (
 	"gvisor.dev/gvisor/pkg/waiter"
 )
 
+const ttl = 64
+
+var (
+	ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
+	ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
+)
+
+func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+	utils.RxICMPv4EchoRequest(e, src, dst, ttl)
+}
+
+func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+	utils.RxICMPv6EchoRequest(e, src, dst, ttl)
+}
+
+func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+	checker.IPv4(t, b,
+		checker.SrcAddr(src),
+		checker.DstAddr(dst),
+		checker.TTL(ttl-1),
+		checker.ICMPv4(
+			checker.ICMPv4Type(header.ICMPv4Echo)))
+}
+
+func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+	checker.IPv6(t, b,
+		checker.SrcAddr(src),
+		checker.DstAddr(dst),
+		checker.TTL(ttl-1),
+		checker.ICMPv6(
+			checker.ICMPv6Type(header.ICMPv6EchoRequest)))
+}
+
 func TestForwarding(t *testing.T) {
 	const listenPort = 8080
 
@@ -320,45 +354,16 @@ func TestMulticastForwarding(t *testing.T) {
 	const (
 		nicID1 = 1
 		nicID2 = 2
-		ttl    = 64
 	)
 
 	var (
 		ipv4LinkLocalUnicastAddr   = testutil.MustParse4("169.254.0.10")
 		ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10")
-		ipv4GlobalMulticastAddr    = testutil.MustParse4("224.0.1.10")
 
 		ipv6LinkLocalUnicastAddr   = testutil.MustParse6("fe80::a")
 		ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a")
-		ipv6GlobalMulticastAddr    = testutil.MustParse6("ff0e::a")
 	)
 
-	rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
-		utils.RxICMPv4EchoRequest(e, src, dst, ttl)
-	}
-
-	rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
-		utils.RxICMPv6EchoRequest(e, src, dst, ttl)
-	}
-
-	v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
-		checker.IPv4(t, b,
-			checker.SrcAddr(src),
-			checker.DstAddr(dst),
-			checker.TTL(ttl-1),
-			checker.ICMPv4(
-				checker.ICMPv4Type(header.ICMPv4Echo)))
-	}
-
-	v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
-		checker.IPv6(t, b,
-			checker.SrcAddr(src),
-			checker.DstAddr(dst),
-			checker.TTL(ttl-1),
-			checker.ICMPv6(
-				checker.ICMPv6Type(header.ICMPv6EchoRequest)))
-	}
-
 	tests := []struct {
 		name             string
 		srcAddr, dstAddr tcpip.Address
@@ -394,7 +399,7 @@ func TestMulticastForwarding(t *testing.T) {
 			rx:            rxICMPv4EchoRequest,
 			expectForward: true,
 			checker: func(t *testing.T, b []byte) {
-				v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+				forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
 			},
 		},
 		{
@@ -404,7 +409,7 @@ func TestMulticastForwarding(t *testing.T) {
 			rx:            rxICMPv4EchoRequest,
 			expectForward: true,
 			checker: func(t *testing.T, b []byte) {
-				v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
+				forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
 			},
 		},
 
@@ -436,7 +441,7 @@ func TestMulticastForwarding(t *testing.T) {
 			rx:            rxICMPv6EchoRequest,
 			expectForward: true,
 			checker: func(t *testing.T, b []byte) {
-				v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+				forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
 			},
 		},
 		{
@@ -446,7 +451,7 @@ func TestMulticastForwarding(t *testing.T) {
 			rx:            rxICMPv6EchoRequest,
 			expectForward: true,
 			checker: func(t *testing.T, b []byte) {
-				v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
+				forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
 			},
 		},
 	}
@@ -506,3 +511,180 @@ func TestMulticastForwarding(t *testing.T) {
 		})
 	}
 }
+
+func TestPerInterfaceForwarding(t *testing.T) {
+	const (
+		nicID1 = 1
+		nicID2 = 2
+	)
+
+	tests := []struct {
+		name             string
+		srcAddr, dstAddr tcpip.Address
+		rx               func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+		checker          func(*testing.T, []byte)
+	}{
+		{
+			name:    "IPv4 unicast",
+			srcAddr: utils.RemoteIPv4Addr,
+			dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
+			rx:      rxICMPv4EchoRequest,
+			checker: func(t *testing.T, b []byte) {
+				forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+			},
+		},
+		{
+			name:    "IPv4 multicast",
+			srcAddr: utils.RemoteIPv4Addr,
+			dstAddr: ipv4GlobalMulticastAddr,
+			rx:      rxICMPv4EchoRequest,
+			checker: func(t *testing.T, b []byte) {
+				forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
+			},
+		},
+
+		{
+			name:    "IPv6 unicast",
+			srcAddr: utils.RemoteIPv6Addr,
+			dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
+			rx:      rxICMPv6EchoRequest,
+			checker: func(t *testing.T, b []byte) {
+				forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+			},
+		},
+		{
+			name:    "IPv6 multicast",
+			srcAddr: utils.RemoteIPv6Addr,
+			dstAddr: ipv6GlobalMulticastAddr,
+			rx:      rxICMPv6EchoRequest,
+			checker: func(t *testing.T, b []byte) {
+				forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
+			},
+		},
+	}
+
+	netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			s := stack.New(stack.Options{
+				NetworkProtocols: []stack.NetworkProtocolFactory{
+					// ARP is not used in this test but it is a network protocol that does
+					// not support forwarding. We install the protocol to make sure that
+					// forwarding information for a NIC is only reported for network
+					// protocols that support forwarding.
+					arp.NewProtocol,
+
+					ipv4.NewProtocol,
+					ipv6.NewProtocol,
+				},
+				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+			})
+
+			e1 := channel.New(1, header.IPv6MinimumMTU, "")
+			if err := s.CreateNIC(nicID1, e1); err != nil {
+				t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
+			}
+
+			e2 := channel.New(1, header.IPv6MinimumMTU, "")
+			if err := s.CreateNIC(nicID2, e2); err != nil {
+				t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
+			}
+
+			for _, add := range [...]struct {
+				nicID tcpip.NICID
+				addr  tcpip.ProtocolAddress
+			}{
+				{
+					nicID: nicID1,
+					addr:  utils.RouterNIC1IPv4Addr,
+				},
+				{
+					nicID: nicID1,
+					addr:  utils.RouterNIC1IPv6Addr,
+				},
+				{
+					nicID: nicID2,
+					addr:  utils.RouterNIC2IPv4Addr,
+				},
+				{
+					nicID: nicID2,
+					addr:  utils.RouterNIC2IPv6Addr,
+				},
+			} {
+				if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil {
+					t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err)
+				}
+			}
+
+			// Only enable forwarding on NIC1 and make sure that only packets arriving
+			// on NIC1 are forwarded.
+			for _, netProto := range netProtos {
+				if err := s.SetNICForwarding(nicID1, netProto, true); err != nil {
+					t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err)
+				}
+			}
+
+			nicsInfo := s.NICInfo()
+			for _, subTest := range [...]struct {
+				nicID            tcpip.NICID
+				nicEP            *channel.Endpoint
+				otherNICID       tcpip.NICID
+				otherNICEP       *channel.Endpoint
+				expectForwarding bool
+			}{
+				{
+					nicID:            nicID1,
+					nicEP:            e1,
+					otherNICID:       nicID2,
+					otherNICEP:       e2,
+					expectForwarding: true,
+				},
+				{
+					nicID:            nicID2,
+					nicEP:            e2,
+					otherNICID:       nicID2,
+					otherNICEP:       e1,
+					expectForwarding: false,
+				},
+			} {
+				t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) {
+					nicInfo, ok := nicsInfo[subTest.nicID]
+					if !ok {
+						t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo)
+					} else {
+						forwarding := make(map[tcpip.NetworkProtocolNumber]bool)
+						for _, netProto := range netProtos {
+							forwarding[netProto] = subTest.expectForwarding
+						}
+
+						if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" {
+							t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff)
+						}
+					}
+
+					s.SetRouteTable([]tcpip.Route{
+						{
+							Destination: header.IPv4EmptySubnet,
+							NIC:         subTest.otherNICID,
+						},
+						{
+							Destination: header.IPv6EmptySubnet,
+							NIC:         subTest.otherNICID,
+						},
+					})
+
+					test.rx(subTest.nicEP, test.srcAddr, test.dstAddr)
+					if p, ok := subTest.nicEP.Read(); ok {
+						t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p)
+					}
+					if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding {
+						t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding)
+					} else if subTest.expectForwarding {
+						test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+					}
+				})
+			}
+		})
+	}
+}
-- 
cgit v1.2.3