diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/sentry/socket/netstack/stack.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/forward_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/loopback_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/tests/utils/utils.go | 8 |
13 files changed, 59 insertions, 52 deletions
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index b215067cf..9cc1c57d7 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -470,11 +470,8 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { // SetForwarding implements inet.Stack.SetForwarding. func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { - switch protocol { - case ipv4.ProtocolNumber, ipv6.ProtocolNumber: - s.Stack.SetForwarding(protocol, enable) - default: - panic(fmt.Sprintf("SetForwarding(%v) failed: unsupported protocol", protocol)) + if err := s.Stack.SetForwardingDefaultAndAllNICs(protocol, enable); err != nil { + return fmt.Errorf("SetForwardingDefaultAndAllNICs(%d, %t): %s", protocol, enable, err) } return nil } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 74aad126c..bd63e0289 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -1996,8 +1996,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) @@ -2005,8 +2005,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, false); err != nil { - t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 3c8a39973..5f45b9ee6 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -389,8 +389,8 @@ func TestForwarding(t *testing.T) { }, }) - if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err) } ipHeaderLength := header.IPv4MinimumSize + len(test.options) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index e457be3cf..040cd4bc8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -673,8 +673,9 @@ func TestICMPChecksumValidationSimple(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index faf6a782e..30325160a 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -3329,8 +3329,8 @@ func TestForwarding(t *testing.T) { }, }) - if err := s.SetForwarding(ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) } transportProtocol := header.ICMPv6ProtocolNumber diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 52b9a200c..570c6c00c 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -878,8 +878,9 @@ func TestNDPValidation(t *testing.T) { s, ep := setup(t) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } stats := s.Stats().ICMP.V6.PacketsReceived diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7d3725681..ff555722e 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -367,8 +367,10 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f }}, }) - // Enable forwarding. - s.SetForwarding(proto.Number(), true) + protoNum := proto.Number() + if err := s.SetForwardingDefaultAndAllNICs(protoNum, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", protoNum, err) + } // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index c585b81b2..d4ac9e1f8 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1220,8 +1220,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { NDPDisp: &ndpDisp, })}, }) - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } e := channel.New(1, 1280, linkAddr1) @@ -1424,8 +1424,8 @@ func TestRouterDiscovery(t *testing.T) { } } - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } if err := s.CreateNIC(1, e); err != nil { @@ -1626,8 +1626,8 @@ func TestPrefixDiscovery(t *testing.T) { } } - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) @@ -1893,8 +1893,8 @@ func TestAutoGenAddr(t *testing.T) { })}, }) - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } if err := s.CreateNIC(1, e); err != nil { @@ -4771,8 +4771,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { // or routers, or auto-generated address. for _, forwarding := range [...]bool{true, false} { t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) { - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } select { case e := <-ndpDisp.routerC: @@ -5353,8 +5353,8 @@ func TestRouterSolicitation(t *testing.T) { name: "Handle RAs always", handleRAs: ipv6.HandlingRAsAlwaysEnabled, afterFirstRS: func(t *testing.T, s *stack.Stack) { - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } }, }, @@ -5481,11 +5481,17 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, false) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", ipv6.ProtocolNumber, err) + } }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } }, }, diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3d9e1e286..483a960c8 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -492,9 +492,9 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables packet forwarding between NICs for the -// passed protocol. -func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { +// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the +// passed protocol and sets the default setting for newly created NICs. +func (s *Stack) SetForwardingDefaultAndAllNICs(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { protocol, ok := s.networkProtocols[protocolNum] if !ok { return &tcpip.ErrUnknownProtocol{} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d2c40cc43..ff88b1bd3 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -4220,8 +4220,8 @@ func TestFindRouteWithForwarding(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) } - if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) } s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) @@ -4275,8 +4275,8 @@ func TestFindRouteWithForwarding(t *testing.T) { // Disabling forwarding when the route is dependent on forwarding being // enabled should make the route invalid. - if err := s.SetForwarding(test.netCfg.proto, false); err != nil { - t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", test.netCfg.proto, err) } { err := send(r, data) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index dbd279c94..42bc53328 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -475,11 +475,11 @@ func TestMulticastForwarding(t *testing.T) { t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) } - if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) } - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } s.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 3df1bbd68..87d36e1dd 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -714,11 +714,11 @@ func TestExternalLoopbackTraffic(t *testing.T) { } if test.forwarding { - if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) } - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } } diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 8fd9be32b..c8b9c9b5c 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -224,11 +224,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) } - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv4.ProtocolNumber, err) } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + if err := routerStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err) } if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil { |