From baa0888f114c586ea490d49a23c3d828fd739b85 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 13 May 2021 11:22:25 -0700 Subject: Rename SetForwarding to SetForwardingDefaultAndAllNICs ...to make it clear to callers that all interfaces are updated with the forwarding flag and that future NICs will be created with the new forwarding state. PiperOrigin-RevId: 373618435 --- pkg/sentry/socket/netstack/stack.go | 7 ++---- pkg/tcpip/network/ip_test.go | 8 +++---- pkg/tcpip/network/ipv4/ipv4_test.go | 4 ++-- pkg/tcpip/network/ipv6/icmp_test.go | 5 ++-- pkg/tcpip/network/ipv6/ipv6_test.go | 4 ++-- pkg/tcpip/network/ipv6/ndp_test.go | 5 ++-- pkg/tcpip/stack/forwarding_test.go | 6 +++-- pkg/tcpip/stack/ndp_test.go | 34 ++++++++++++++++------------ pkg/tcpip/stack/stack.go | 6 ++--- pkg/tcpip/stack/stack_test.go | 8 +++---- pkg/tcpip/tests/integration/forward_test.go | 8 +++---- pkg/tcpip/tests/integration/loopback_test.go | 8 +++---- pkg/tcpip/tests/utils/utils.go | 8 +++---- 13 files changed, 59 insertions(+), 52 deletions(-) (limited to 'pkg') diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index b215067cf..9cc1c57d7 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -470,11 +470,8 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { // SetForwarding implements inet.Stack.SetForwarding. func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { - switch protocol { - case ipv4.ProtocolNumber, ipv6.ProtocolNumber: - s.Stack.SetForwarding(protocol, enable) - default: - panic(fmt.Sprintf("SetForwarding(%v) failed: unsupported protocol", protocol)) + if err := s.Stack.SetForwardingDefaultAndAllNICs(protocol, enable); err != nil { + return fmt.Errorf("SetForwardingDefaultAndAllNICs(%d, %t): %s", protocol, enable, err) } return nil } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 74aad126c..bd63e0289 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -1996,8 +1996,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) @@ -2005,8 +2005,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, false); err != nil { - t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 3c8a39973..5f45b9ee6 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -389,8 +389,8 @@ func TestForwarding(t *testing.T) { }, }) - if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err) } ipHeaderLength := header.IPv4MinimumSize + len(test.options) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index e457be3cf..040cd4bc8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -673,8 +673,9 @@ func TestICMPChecksumValidationSimple(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index faf6a782e..30325160a 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -3329,8 +3329,8 @@ func TestForwarding(t *testing.T) { }, }) - if err := s.SetForwarding(ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) } transportProtocol := header.ICMPv6ProtocolNumber diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 52b9a200c..570c6c00c 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -878,8 +878,9 @@ func TestNDPValidation(t *testing.T) { s, ep := setup(t) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } stats := s.Stats().ICMP.V6.PacketsReceived diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7d3725681..ff555722e 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -367,8 +367,10 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f }}, }) - // Enable forwarding. - s.SetForwarding(proto.Number(), true) + protoNum := proto.Number() + if err := s.SetForwardingDefaultAndAllNICs(protoNum, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", protoNum, err) + } // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index c585b81b2..d4ac9e1f8 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1220,8 +1220,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { NDPDisp: &ndpDisp, })}, }) - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } e := channel.New(1, 1280, linkAddr1) @@ -1424,8 +1424,8 @@ func TestRouterDiscovery(t *testing.T) { } } - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } if err := s.CreateNIC(1, e); err != nil { @@ -1626,8 +1626,8 @@ func TestPrefixDiscovery(t *testing.T) { } } - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) @@ -1893,8 +1893,8 @@ func TestAutoGenAddr(t *testing.T) { })}, }) - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } if err := s.CreateNIC(1, e); err != nil { @@ -4771,8 +4771,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { // or routers, or auto-generated address. for _, forwarding := range [...]bool{true, false} { t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) { - if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } select { case e := <-ndpDisp.routerC: @@ -5353,8 +5353,8 @@ func TestRouterSolicitation(t *testing.T) { name: "Handle RAs always", handleRAs: ipv6.HandlingRAsAlwaysEnabled, afterFirstRS: func(t *testing.T, s *stack.Stack) { - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } }, }, @@ -5481,11 +5481,17 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, false) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", ipv6.ProtocolNumber, err) + } }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } }, }, diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3d9e1e286..483a960c8 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -492,9 +492,9 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables packet forwarding between NICs for the -// passed protocol. -func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { +// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the +// passed protocol and sets the default setting for newly created NICs. +func (s *Stack) SetForwardingDefaultAndAllNICs(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { protocol, ok := s.networkProtocols[protocolNum] if !ok { return &tcpip.ErrUnknownProtocol{} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d2c40cc43..ff88b1bd3 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -4220,8 +4220,8 @@ func TestFindRouteWithForwarding(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) } - if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) } s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) @@ -4275,8 +4275,8 @@ func TestFindRouteWithForwarding(t *testing.T) { // Disabling forwarding when the route is dependent on forwarding being // enabled should make the route invalid. - if err := s.SetForwarding(test.netCfg.proto, false); err != nil { - t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", test.netCfg.proto, err) } { err := send(r, data) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index dbd279c94..42bc53328 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -475,11 +475,11 @@ func TestMulticastForwarding(t *testing.T) { t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) } - if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) } - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } s.SetRouteTable([]tcpip.Route{ diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 3df1bbd68..87d36e1dd 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -714,11 +714,11 @@ func TestExternalLoopbackTraffic(t *testing.T) { } if test.forwarding { - if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) } - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } } diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 8fd9be32b..c8b9c9b5c 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -224,11 +224,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) } - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv4.ProtocolNumber, err) } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + if err := routerStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err) } if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil { -- cgit v1.2.3