diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 3 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state.go | 52 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 43 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_types.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 1173 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache_test.go | 66 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_test.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 88 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic_stats.go | 74 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic_test.go | 32 | ||||
-rw-r--r-- | pkg/tcpip/stack/nud.go | 71 | ||||
-rw-r--r-- | pkg/tcpip/stack/nud_test.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 383 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer_test.go | 200 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 163 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 205 |
20 files changed, 1542 insertions, 1109 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 2bd6a67f5..395ff9a07 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -56,6 +56,7 @@ go_library( "neighbor_entry_list.go", "neighborstate_string.go", "nic.go", + "nic_stats.go", "nud.go", "packet_buffer.go", "packet_buffer_list.go", @@ -73,6 +74,8 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/atomicbitops", + "//pkg/buffer", "//pkg/ilist", "//pkg/log", "//pkg/rand", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index e5590ecc0..ce9cebdaa 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -440,33 +440,54 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad // Regardless how the address was obtained, it will be acquired before it is // returned. func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { - a.mu.Lock() - defer a.mu.Unlock() + lookup := func() *addressState { + if addrState, ok := a.mu.endpoints[localAddr]; ok { + if !addrState.IsAssigned(allowTemp) { + return nil + } - if addrState, ok := a.mu.endpoints[localAddr]; ok { - if !addrState.IsAssigned(allowTemp) { - return nil - } + if !addrState.IncRef() { + panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr)) + } - if !addrState.IncRef() { - panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr)) + return addrState } - return addrState - } - - if f != nil { - for _, addrState := range a.mu.endpoints { - if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() { - return addrState + if f != nil { + for _, addrState := range a.mu.endpoints { + if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() { + return addrState + } } } + return nil + } + // Avoid exclusive lock on mu unless we need to add a new address. + a.mu.RLock() + ep := lookup() + a.mu.RUnlock() + + if ep != nil { + return ep } if !allowTemp { return nil } + // Acquire state lock in exclusive mode as we need to add a new temporary + // endpoint. + a.mu.Lock() + defer a.mu.Unlock() + + // Do the lookup again in case another goroutine added the address in the time + // we released and acquired the lock. + ep = lookup() + if ep != nil { + return ep + } + + // Proceed to add a new temporary endpoint. addr := localAddr.WithPrefix() ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */) if err != nil { @@ -475,6 +496,7 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc // expect no error. panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err)) } + // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 2d74e0abc..d971db010 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -54,6 +54,11 @@ type fwdTestNetworkEndpoint struct { nic NetworkInterface proto *fwdTestNetworkProtocol dispatcher TransportDispatcher + + mu struct { + sync.RWMutex + forwarding bool + } } func (*fwdTestNetworkEndpoint) Enable() tcpip.Error { @@ -101,7 +106,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: vv.ToView().ToVectorisedView(), }) - // TODO(b/143425874) Decrease the TTL field in forwarded packets. + // TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets. _ = r.WriteHeaderIncludedPacket(pkt) } @@ -109,10 +114,6 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } -func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -129,7 +130,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParam } // WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -169,11 +170,6 @@ type fwdTestNetworkProtocol struct { addrResolveDelay time.Duration onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) - - mu struct { - sync.RWMutex - forwarding bool - } } func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -242,16 +238,16 @@ func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber return fwdTestNetNumber } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) Forwarding() bool { +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (f *fwdTestNetworkEndpoint) Forwarding() bool { f.mu.RLock() defer f.mu.RUnlock() return f.mu.forwarding } -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) SetForwarding(v bool) { +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (f *fwdTestNetworkEndpoint) SetForwarding(v bool) { f.mu.Lock() defer f.mu.Unlock() f.mu.forwarding = v @@ -264,6 +260,8 @@ type fwdTestPacketInfo struct { Pkt *PacketBuffer } +var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil) + type fwdTestLinkEndpoint struct { dispatcher NetworkDispatcher mtu uint32 @@ -306,11 +304,6 @@ func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { return caps | CapabilityResolutionRequired } -// GSOMaxSize returns the maximum GSO packet size. -func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { - return 1 << 15 -} - // MaxHeaderLength returns the maximum size of the link layer header. Given it // doesn't have a header, it just returns 0. func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { @@ -322,7 +315,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -357,7 +350,7 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) { panic("not implemented") } @@ -370,8 +363,10 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f }}, }) - // Enable forwarding. - s.SetForwarding(proto.Number(), true) + protoNum := proto.Number() + if err := s.SetForwardingDefaultAndAllNICs(protoNum, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", protoNum, err) + } // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index e2894c548..3670d5995 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -177,6 +177,7 @@ func DefaultTables() *IPTables { priorities: [NumHooks][]TableID{ Prerouting: {MangleID, NATID}, Input: {NATID, FilterID}, + Forward: {FilterID}, Output: {MangleID, NATID, FilterID}, Postrouting: {MangleID, NATID}, }, diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 4631ab93f..93592e7f5 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -280,9 +280,18 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) case Output: return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) - case Forward, Postrouting: - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. + case Forward: + if !matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) { + return false + } + + if !matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) { + return false + } + + return true + case Postrouting: + // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING. return true default: panic(fmt.Sprintf("unknown hook: %d", hook)) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index b6cf24739..ac2fa777e 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -481,13 +481,9 @@ func TestDADResolve(t *testing.T) { } for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } e := channelLinkWithHeaderLength{ @@ -499,7 +495,9 @@ func TestDADResolve(t *testing.T) { var secureRNG bytes.Reader secureRNG.Reset(secureRNGBytes) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ + Clock: clock, SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, @@ -529,14 +527,10 @@ func TestDADResolve(t *testing.T) { t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) } - // Address should not be considered bound to the NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - // Make sure the address does not resolve before the resolution time has // passed. - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout) + const delta = time.Nanosecond + clock.Advance(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Error(err) } @@ -566,13 +560,14 @@ func TestDADResolve(t *testing.T) { } // Wait for DAD to resolve. + clock.Advance(delta) select { - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatalf("expected DAD event for %s on NIC(%d)", addr1, nicID) } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { t.Error(err) @@ -1146,57 +1141,198 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on }) } -// TestNoRouterDiscovery tests that router discovery will not be performed if -// configured not to. -func TestNoRouterDiscovery(t *testing.T) { - // Being configured to discover routers means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // router discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverDefaultRouters: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) +func TestDynamicConfigurationsDisabled(t *testing.T) { + const ( + nicID = 1 + maxRtrSolicitDelay = time.Second + ) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + prefix := tcpip.AddressWithPrefix{ + Address: testutil.MustParse6("102:304:506:708::"), + PrefixLen: 64, + } - // Rx an RA with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router when configured not to") - default: + tests := []struct { + name string + config func(bool) ipv6.NDPConfigurations + ra *stack.PacketBuffer + }{ + { + name: "No Router Discovery", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable} + }, + ra: raBuf(llAddr2, 1000), + }, + { + name: "No Prefix Discovery", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{DiscoverOnLinkPrefixes: enable} + }, + ra: raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0), + }, + { + name: "No Autogenerate Addresses", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{AutoGenGlobalAddresses: enable} + }, + ra: raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Being configured to discover routers/prefixes or auto-generate + // addresses means RAs must be handled, and router/prefix discovery or + // SLAAC must be enabled. + // + // This tests all possible combinations of the configurations where + // router/prefix discovery or SLAAC are disabled. + for i := 0; i < 7; i++ { + handle := ipv6.HandlingRAsDisabled + if i&1 != 0 { + handle = ipv6.HandlingRAsEnabledWhenForwardingDisabled + } + enable := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + ndpConfigs := test.config(enable) + ndpConfigs.HandleRAs = handle + ndpConfigs.MaxRtrSolicitations = 1 + ndpConfigs.RtrSolicitationInterval = maxRtrSolicitDelay + ndpConfigs.MaxRtrSolicitationDelay = maxRtrSolicitDelay + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + Clock: clock, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, + }) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } + + e := channel.New(1, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + handleRAsDisabled := handle == ipv6.HandlingRAsDisabled || forwarding + ep, err := s.GetNetworkEndpoint(nicID, ipv6.ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ipv6.ProtocolNumber, err) + } + stats := ep.Stats() + v6Stats, ok := stats.(*ipv6.Stats) + if !ok { + t.Fatalf("got v6Stats = %T, expected = %T", stats, v6Stats) + } + + // Make sure that when handling RAs are enabled, we solicit routers. + clock.Advance(maxRtrSolicitDelay) + if got, want := v6Stats.ICMP.PacketsSent.RouterSolicit.Value(), boolToUint64(!handleRAsDisabled); got != want { + t.Errorf("got v6Stats.ICMP.PacketsSent.RouterSolicit.Value() = %d, want = %d", got, want) + } + if handleRAsDisabled { + if p, ok := e.Read(); ok { + t.Errorf("unexpectedly got a packet = %#v", p) + } + } else if p, ok := e.Read(); !ok { + t.Error("expected router solicitation packet") + } else if p.Proto != header.IPv6ProtocolNumber { + t.Errorf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } else { + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(nil)), + ) + } + + // Make sure we do not discover any routers or prefixes, or perform + // SLAAC on reception of an RA. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone()) + // Make sure that the unhandled RA stat is only incremented when + // handling RAs is disabled. + if got, want := v6Stats.UnhandledRouterAdvertisements.Value(), boolToUint64(handleRAsDisabled); got != want { + t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) + } + select { + case e := <-ndpDisp.routerC: + t.Errorf("unexpectedly discovered a router when configured not to: %#v", e) + default: + } + select { + case e := <-ndpDisp.prefixC: + t.Errorf("unexpectedly discovered a prefix when configured not to: %#v", e) + default: + } + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("unexpectedly auto-generated an address when configured not to: %#v", e) + default: + } + }) } }) } } +func boolToUint64(v bool) uint64 { + if v { + return 1 + } + return 0 +} + // Check e to make sure that the event is for addr on nic with ID 1, and the // discovered flag set to discovered. func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) } +func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { + tests := [...]struct { + name string + handleRAs ipv6.HandleRAsConfiguration + forwarding bool + }{ + { + name: "Handle RAs when forwarding disabled", + handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + forwarding: false, + }, + { + name: "Always Handle RAs with forwarding disabled", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + forwarding: false, + }, + { + name: "Always Handle RAs with forwarding enabled", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + forwarding: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f(t, test.handleRAs, test.forwarding) + }) + } +} + // TestRouterDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered router when the dispatcher asks it not to. func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { @@ -1207,7 +1343,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, }, NDPDisp: &ndpDisp, @@ -1241,103 +1377,109 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } func TestRouterDiscovery(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - expectRouterEvent := func(addr tcpip.Address, discovered bool) { - t.Helper() + expectRouterEvent := func(addr tcpip.Address, discovered bool) { + t.Helper() - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, discovered); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") } - default: - t.Fatal("expected router discovery event") } - } - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { - t.Helper() + expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, false); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for router discovery event") } - case <-time.After(timeout): - t.Fatal("timed out waiting for router discovery event") } - } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA from lladdr2 with zero lifetime. It should not be - // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") - default: - } - - // Rx an RA from lladdr2 with a huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Rx an RA from another router (lladdr3) with non-zero lifetime. - const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - // Rx an RA from lladdr2 with lesser lifetime. - const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) - select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") - default: - } + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router with 0 lifetime") + default: + } - // Wait for lladdr2's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + // Rx an RA from lladdr2 with a huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) - // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + // Rx an RA from another router (lladdr3) with non-zero lifetime. + const l3LifetimeSeconds = 6 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) + expectRouterEvent(llAddr3, true) - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) + // Rx an RA from lladdr2 with lesser lifetime. + const l2LifetimeSeconds = 2 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) + select { + case <-ndpDisp.routerC: + t.Fatal("Should not receive a router event when updating lifetimes for known routers") + default: + } - // Wait for lladdr3's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + // Wait for lladdr2's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + expectRouterEvent(llAddr2, false) + + // Wait for lladdr3's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + }) } // TestRouterDiscoveryMaxRouters tests that only @@ -1351,7 +1493,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, }, NDPDisp: &ndpDisp, @@ -1390,57 +1532,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } } -// TestNoPrefixDiscovery tests that prefix discovery will not be performed if -// configured not to. -func TestNoPrefixDiscovery(t *testing.T) { - prefix := tcpip.AddressWithPrefix{ - Address: testutil.MustParse6("102:304:506:708::"), - PrefixLen: 64, - } - - // Being configured to discover prefixes means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // prefix discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverOnLinkPrefixes: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) - - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix when configured not to") - default: - } - }) - } -} - // Check e to make sure that the event is for prefix on nic with ID 1, and the // discovered flag set to discovered. func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { @@ -1459,8 +1550,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverOnLinkPrefixes: true, }, NDPDisp: &ndpDisp, @@ -1498,87 +1588,93 @@ func TestPrefixDiscovery(t *testing.T) { prefix2, subnet2, _ := prefixSubnetAddr(1, "") prefix3, subnet3, _ := prefixSubnetAddr(2, "") - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") } - default: - t.Fatal("expected prefix discovery event") } - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - default: - } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) - expectPrefixEvent(subnet1, true) + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix with 0 lifetime") + default: + } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) - expectPrefixEvent(subnet2, true) + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) + expectPrefixEvent(subnet1, true) - // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) - expectPrefixEvent(subnet3, true) + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) + expectPrefixEvent(subnet2, true) - // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - expectPrefixEvent(subnet1, false) + // Receive an RA with prefix3 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) + expectPrefixEvent(subnet3, true) - // Receive an RA with prefix2 in a PI with lesser lifetime. - lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly received prefix event when updating lifetime") - default: - } + // Receive an RA with prefix1 in a PI with lifetime = 0. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + expectPrefixEvent(subnet1, false) - // Wait for prefix2's most recent invalidation job plus some buffer to - // expire. - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet2, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + // Receive an RA with prefix2 in a PI with lesser lifetime. + lifetime := uint32(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly received prefix event when updating lifetime") + default: + } + + // Wait for prefix2's most recent invalidation job plus some buffer to + // expire. + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet2, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for prefix discovery event") } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for prefix discovery event") - } - // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) - expectPrefixEvent(subnet3, false) + // Receive RA to invalidate prefix3. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) + expectPrefixEvent(subnet3, false) + }) } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { @@ -1607,7 +1703,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverOnLinkPrefixes: true, }, NDPDisp: &ndpDisp, @@ -1692,7 +1788,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: false, DiscoverOnLinkPrefixes: true, }, @@ -1757,53 +1853,6 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) return containsAddr(list, protocolAddress) } -// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. -func TestNoAutoGenAddr(t *testing.T) { - prefix, _, _ := prefixSubnetAddr(0, "") - - // Being configured to auto-generate addresses means handle and - // autogen are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, autogen = - // true and forwarding = false (the required configuration to do - // SLAAC) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - autogen := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - AutoGenGlobalAddresses: autogen, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) - - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when configured not to") - default: - } - }) - } -} - // Check e to make sure that the event is for addr on nic with ID 1, and the // event type is set to eventType. func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { @@ -1812,7 +1861,7 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. -func TestAutoGenAddr2(t *testing.T) { +func TestAutoGenAddr(t *testing.T) { const newMinVL = 2 newMinVLDuration := newMinVL * time.Second saved := ipv6.MinPrefixInformationValidLifetimeForUpdate @@ -1824,96 +1873,102 @@ func TestAutoGenAddr2(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") } - default: - t.Fatal("expected addr auto gen event") } - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with 0 lifetime") - default: - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with 0 lifetime") + default: + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") - default: - } + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") + default: + } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } - // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") - default: - } + // Refresh valid lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") + default: + } - // Wait for addr of prefix1 to be invalidated. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + // Wait for addr of prefix1 to be invalidated. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } + if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } + }) } func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string { @@ -2001,7 +2056,7 @@ func TestAutoGenTempAddr(t *testing.T) { RetransmitTimer: test.retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -2302,7 +2357,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { RetransmitTimer: retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -2389,7 +2444,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, @@ -2538,7 +2593,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, @@ -2739,7 +2794,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { Clock: clock, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: test.tempAddrs, AutoGenAddressConflictRetries: 1, @@ -2884,7 +2939,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: ndpDisp, @@ -3351,7 +3406,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3494,7 +3549,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3561,7 +3616,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3727,7 +3782,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3809,7 +3864,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3973,7 +4028,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { @@ -4000,7 +4055,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Temporary address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -4150,7 +4205,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, }, @@ -4278,7 +4333,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { RetransmitTimer: retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, }, @@ -4484,7 +4539,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -4535,7 +4590,7 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -4629,8 +4684,110 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { } } -// TestCleanupNDPState tests that all discovered routers and prefixes, and -// auto-generated addresses are invalidated when a NIC becomes a router. +func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { + const ( + lifetimeSeconds = 999 + nicID = 1 + ) + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenLinkLocal: true, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + }) + + e1 := channel.New(0, header.IPv6MinimumMTU, linkAddr1) + if err := s.CreateNIC(nicID, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + llAddr := tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen} + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, llAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", llAddr, nicID) + } + + prefix, subnet, addr := prefixSubnetAddr(0, linkAddr1) + e1.InjectInbound( + header.IPv6ProtocolNumber, + raBufWithPI( + llAddr3, + lifetimeSeconds, + prefix, + true, /* onLink */ + true, /* auto */ + lifetimeSeconds, + lifetimeSeconds, + ), + ) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID) + } + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) + } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", addr, nicID) + } + + // Enabling or disabling forwarding should not invalidate discovered prefixes + // or routers, or auto-generated address. + for _, forwarding := range [...]bool{true, false} { + t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) { + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } + select { + case e := <-ndpDisp.routerC: + t.Errorf("unexpected router event = %#v", e) + default: + } + select { + case e := <-ndpDisp.prefixC: + t.Errorf("unexpected prefix event = %#v", e) + default: + } + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("unexpected auto-gen addr event = %#v", e) + default: + } + }) + } +} + func TestCleanupNDPState(t *testing.T) { const ( lifetimeSeconds = 5 @@ -4659,18 +4816,6 @@ func TestCleanupNDPState(t *testing.T) { maxAutoGenAddrEvents int skipFinalAddrCheck bool }{ - // A NIC should still keep its auto-generated link-local address when - // becoming a router. - { - name: "Enable forwarding", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) - }, - keepAutoGenLinkLocal: true, - maxAutoGenAddrEvents: 4, - }, - // A NIC should cleanup all NDP state when it is disabled. { name: "Disable NIC", @@ -4722,7 +4867,7 @@ func TestCleanupNDPState(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, DiscoverOnLinkPrefixes: true, AutoGenGlobalAddresses: true, @@ -4995,7 +5140,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -5186,96 +5331,127 @@ func TestRouterSolicitation(t *testing.T) { }, } + subTests := []struct { + name string + handleRAs ipv6.HandleRAsConfiguration + afterFirstRS func(*testing.T, *stack.Stack) + }{ + { + name: "Handle RAs when forwarding disabled", + handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + afterFirstRS: func(*testing.T, *stack.Stack) {}, + }, + + // Enabling forwarding when RAs are always configured to be handled + // should not stop router solicitations. + { + name: "Handle RAs always", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + afterFirstRS: func(t *testing.T, s *stack.Stack) { + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } + }, + }, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + clock := faketime.NewManualClock() + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + waitForPkt := func(timeout time.Duration) { + t.Helper() + + clock.Advance(timeout) + p, ok := e.Read() + if !ok { + t.Fatal("expected router solicitation packet") + } - clock.Advance(timeout) - p, ok := e.Read() - if !ok { - t.Fatal("expected router solicitation packet") - } + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } + // Make sure the right remote link address is used. + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } - // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.expectedSrcAddr), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), + ) - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) + } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() + clock.Advance(timeout) + if p, ok := e.Read(); ok { + t.Fatalf("unexpectedly got a packet = %#v", p) + } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: subTest.handleRAs, + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - clock.Advance(timeout) - if p, ok := e.Read(); ok { - t.Fatalf("unexpectedly got a packet = %#v", p) - } - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + } + } - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } - } + // Make sure each RS is sent at the right time. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay) + remaining-- + } - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay) - remaining-- - } + subTest.afterFirstRS(t, s) - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) - waitForPkt(time.Nanosecond) - } else { - waitForPkt(test.effectiveRtrSolicitInt) - } - } + for ; remaining > 0; remaining-- { + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) + waitForPkt(time.Nanosecond) + } else { + waitForPkt(test.effectiveRtrSolicitInt) + } + } - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt) - } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay) - } + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay) + } - if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) } }) } @@ -5300,11 +5476,17 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, false) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", ipv6.ProtocolNumber, err) + } }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } }, }, @@ -5373,6 +5555,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, MaxRtrSolicitations: maxRtrSolicitations, RtrSolicitationInterval: interval, MaxRtrSolicitationDelay: delay, diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 48bb75e2f..90881169d 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -15,8 +15,6 @@ package stack import ( - "bytes" - "encoding/binary" "fmt" "math" "math/rand" @@ -48,9 +46,6 @@ const ( // be sent to all nodes. testEntryBroadcastAddr = tcpip.Address("broadcast") - // testEntryLocalAddr is the source address of neighbor probes. - testEntryLocalAddr = tcpip.Address("local_addr") - // testEntryBroadcastLinkAddr is a special link address sent back to // multicast neighbor probes. testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") @@ -95,7 +90,7 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl randomGenerator: rng, }, id: 1, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), }, linkRes) return linkRes } @@ -106,20 +101,24 @@ type testEntryStore struct { entriesMap map[tcpip.Address]NeighborEntry } -func toAddress(i int) tcpip.Address { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint16(i)) - return tcpip.Address(buf.String()) +func toAddress(i uint16) tcpip.Address { + return tcpip.Address([]byte{ + 1, + 0, + byte(i >> 8), + byte(i), + }) } -func toLinkAddress(i int) tcpip.LinkAddress { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint32(i)) - return tcpip.LinkAddress(buf.String()) +func toLinkAddress(i uint16) tcpip.LinkAddress { + return tcpip.LinkAddress([]byte{ + 1, + 0, + 0, + 0, + byte(i >> 8), + byte(i), + }) } // newTestEntryStore returns a testEntryStore pre-populated with entries. @@ -127,7 +126,7 @@ func newTestEntryStore() *testEntryStore { store := &testEntryStore{ entriesMap: make(map[tcpip.Address]NeighborEntry), } - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) linkAddr := toLinkAddress(i) @@ -140,15 +139,15 @@ func newTestEntryStore() *testEntryStore { } // size returns the number of entries in the store. -func (s *testEntryStore) size() int { +func (s *testEntryStore) size() uint16 { s.mu.RLock() defer s.mu.RUnlock() - return len(s.entriesMap) + return uint16(len(s.entriesMap)) } // entry returns the entry at index i. Returns an empty entry and false if i is // out of bounds. -func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { +func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) { return s.entryByAddr(toAddress(i)) } @@ -166,7 +165,7 @@ func (s *testEntryStore) entries() []NeighborEntry { entries := make([]NeighborEntry, 0, len(s.entriesMap)) s.mu.RLock() defer s.mu.RUnlock() - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) if entry, ok := s.entriesMap[addr]; ok { entries = append(entries, entry) @@ -176,7 +175,7 @@ func (s *testEntryStore) entries() []NeighborEntry { } // set modifies the link addresses of an entry. -func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { +func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) { addr := toAddress(i) s.mu.Lock() defer s.mu.Unlock() @@ -236,13 +235,6 @@ func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return 0 } -type entryEvent struct { - nicID tcpip.NICID - address tcpip.Address - linkAddr tcpip.LinkAddress - state NeighborState -} - func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() @@ -461,7 +453,7 @@ func newTestContext(c NUDConfigurations) testContext { } type overflowOptions struct { - startAtEntryIndex int + startAtEntryIndex uint16 wantStaticEntries []NeighborEntry } @@ -1068,7 +1060,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // periodically refreshes the frequently used entry. // Fill the neighbor cache to capacity - for i := 0; i < neighborCacheSize; i++ { + for i := uint16(0); i < neighborCacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) @@ -1084,7 +1076,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Keep adding more entries - for i := neighborCacheSize; i < linkRes.entries.size(); i++ { + for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { @@ -1556,14 +1548,14 @@ func TestNeighborCacheRetryResolution(t *testing.T) { func BenchmarkCacheClear(b *testing.B) { b.StopTimer() config := DefaultNUDConfigurations() - clock := &tcpip.StdClock{} + clock := tcpip.NewStdClock() linkRes := newTestNeighborResolver(nil, config, clock) linkRes.delay = 0 // Clear for every possible size of the cache - for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. - for i := 0; i < cacheSize; i++ { + for i := uint16(0); i < cacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 6d95e1664..463d017fc 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -307,7 +307,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr) @@ -361,7 +361,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { e.dispatchAddEventLocked() case Unreachable: e.dispatchChangeEventLocked() - e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment() + e.cache.nic.stats.neighbor.unreachableEntryLookups.Increment() } config := e.nudState.Config() @@ -378,7 +378,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { // As per RFC 4861 section 7.2.2: diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 1d39ee73d..c2a291244 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -36,11 +36,6 @@ const ( entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") - - // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match the - // MTU of loopback interfaces on Linux systems. - entryTestNetDefaultMTU = 65536 ) var ( @@ -196,13 +191,13 @@ func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.A // ResolveStaticAddress attempts to resolve address without sending requests. // It either resolves the name immediately or returns the empty LinkAddress. -func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) { return "", false } // LinkAddressProtocol returns the network protocol of the addresses this // resolver can resolve. -func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return entryTestNetNumber } @@ -219,7 +214,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e nudConfigs: c, randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), }, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8d615500f..378389db2 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -51,7 +51,7 @@ type nic struct { name string context NICContext - stats NICStats + stats sharedStats // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. @@ -78,26 +78,13 @@ type nic struct { } } -// NICStats hold statistics for a NIC. -type NICStats struct { - Tx DirectionStats - Rx DirectionStats - - DisabledRx DirectionStats - - Neighbor NeighborStats -} - -func makeNICStats() NICStats { - var s NICStats - tcpip.InitStatCounters(reflect.ValueOf(&s).Elem()) - return s -} - -// DirectionStats includes packet and byte counts. -type DirectionStats struct { - Packets *tcpip.StatCounter - Bytes *tcpip.StatCounter +// makeNICStats initializes the NIC statistics and associates them to the global +// NIC statistics. +func makeNICStats(global tcpip.NICStats) sharedStats { + var stats sharedStats + tcpip.InitStatCounters(reflect.ValueOf(&stats.local).Elem()) + stats.init(&stats.local, &global) + return stats } type packetEndpointList struct { @@ -150,7 +137,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC id: id, name: name, context: ctx, - stats: makeNICStats(), + stats: makeNICStats(stack.Stats().NICs), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver), duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), @@ -382,8 +369,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt return err } - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + n.stats.tx.packets.Increment() + n.stats.tx.bytes.IncrementBy(uint64(numBytes)) return nil } @@ -399,13 +386,13 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk } writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) - n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) + n.stats.tx.packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { writtenBytes += pb.Size() } - n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + n.stats.tx.bytes.IncrementBy(uint64(writtenBytes)) return writtenPackets, err } @@ -718,18 +705,18 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if !enabled { n.mu.RUnlock() - n.stats.DisabledRx.Packets.Increment() - n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.disabledRx.packets.Increment() + n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size())) return } - n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.rx.packets.Increment() + n.stats.rx.bytes.IncrementBy(uint64(pkt.Data().Size())) networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL3ProtocolRcvdPackets.Increment() return } @@ -786,7 +773,7 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL4ProtocolRcvdPackets.Increment() return TransportPacketProtocolUnreachable } @@ -807,20 +794,20 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() // We consider a malformed transport packet handled because there is // nothing the caller can do. return TransportPacketHandled } } else if !transProto.Parse(pkt) { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } @@ -852,7 +839,7 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // If it doesn't handle it then we should do so. switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled case UnknownDestinationPacketUnhandled: return TransportPacketDestinationPortUnreachable @@ -1000,3 +987,32 @@ func (n *nic) checkDuplicateAddress(protocol tcpip.NetworkProtocolNumber, addr t return d.CheckDuplicateAddress(addr, h), nil } + +func (n *nic) setForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + ep := n.getNetworkEndpoint(protocol) + if ep == nil { + return &tcpip.ErrUnknownProtocol{} + } + + forwardingEP, ok := ep.(ForwardingNetworkEndpoint) + if !ok { + return &tcpip.ErrNotSupported{} + } + + forwardingEP.SetForwarding(enable) + return nil +} + +func (n *nic) forwarding(protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) { + ep := n.getNetworkEndpoint(protocol) + if ep == nil { + return false, &tcpip.ErrUnknownProtocol{} + } + + forwardingEP, ok := ep.(ForwardingNetworkEndpoint) + if !ok { + return false, &tcpip.ErrNotSupported{} + } + + return forwardingEP.Forwarding(), nil +} diff --git a/pkg/tcpip/stack/nic_stats.go b/pkg/tcpip/stack/nic_stats.go new file mode 100644 index 000000000..1773d5e8d --- /dev/null +++ b/pkg/tcpip/stack/nic_stats.go @@ -0,0 +1,74 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + +type sharedStats struct { + local tcpip.NICStats + multiCounterNICStats +} + +// LINT.IfChange(multiCounterNICPacketStats) + +type multiCounterNICPacketStats struct { + packets tcpip.MultiCounterStat + bytes tcpip.MultiCounterStat +} + +func (m *multiCounterNICPacketStats) init(a, b *tcpip.NICPacketStats) { + m.packets.Init(a.Packets, b.Packets) + m.bytes.Init(a.Bytes, b.Bytes) +} + +// LINT.ThenChange(../../tcpip.go:NICPacketStats) + +// LINT.IfChange(multiCounterNICNeighborStats) + +type multiCounterNICNeighborStats struct { + unreachableEntryLookups tcpip.MultiCounterStat +} + +func (m *multiCounterNICNeighborStats) init(a, b *tcpip.NICNeighborStats) { + m.unreachableEntryLookups.Init(a.UnreachableEntryLookups, b.UnreachableEntryLookups) +} + +// LINT.ThenChange(../../tcpip.go:NICNeighborStats) + +// LINT.IfChange(multiCounterNICStats) + +type multiCounterNICStats struct { + unknownL3ProtocolRcvdPackets tcpip.MultiCounterStat + unknownL4ProtocolRcvdPackets tcpip.MultiCounterStat + malformedL4RcvdPackets tcpip.MultiCounterStat + tx multiCounterNICPacketStats + rx multiCounterNICPacketStats + disabledRx multiCounterNICPacketStats + neighbor multiCounterNICNeighborStats +} + +func (m *multiCounterNICStats) init(a, b *tcpip.NICStats) { + m.unknownL3ProtocolRcvdPackets.Init(a.UnknownL3ProtocolRcvdPackets, b.UnknownL3ProtocolRcvdPackets) + m.unknownL4ProtocolRcvdPackets.Init(a.UnknownL4ProtocolRcvdPackets, b.UnknownL4ProtocolRcvdPackets) + m.malformedL4RcvdPackets.Init(a.MalformedL4RcvdPackets, b.MalformedL4RcvdPackets) + m.tx.init(&a.Tx, &b.Tx) + m.rx.init(&a.Rx, &b.Rx) + m.disabledRx.init(&a.DisabledRx, &b.DisabledRx) + m.neighbor.init(&a.Neighbor, &b.Neighbor) +} + +// LINT.ThenChange(../../tcpip.go:NICStats) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 8a3005295..5cb342f78 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,13 @@ package stack import ( + "reflect" "testing" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) @@ -171,19 +173,19 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. nic := nic{ - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } - if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 { t.Errorf("got DisabledRx.Packets = %d, want = 0", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 { t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } @@ -195,16 +197,28 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), })) - if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 { t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } } + +func TestMultiCounterStatsInitialization(t *testing.T) { + global := tcpip.NICStats{}.FillIn() + nic := nic{ + stats: makeNICStats(global), + } + multi := nic.stats.multiCounterNICStats + local := nic.stats.local + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 5a94e9ac6..02f905351 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -323,26 +323,21 @@ type Rand interface { type NUDState struct { rng Rand - // mu protects the fields below. - // - // It is necessary for NUDState to handle its own locking since neighbor - // entries may access the NUD state from within the goroutine spawned by - // time.AfterFunc(). This goroutine may run concurrently with the main - // process for controlling the neighbor cache and would otherwise introduce - // race conditions if NUDState was not locked properly. - mu sync.RWMutex - - config NUDConfigurations - - // reachableTime is the duration to wait for a REACHABLE entry to - // transition into STALE after inactivity. This value is calculated with - // the algorithm defined in RFC 4861 section 6.3.2. - reachableTime time.Duration - - expiration time.Time - prevBaseReachableTime time.Duration - prevMinRandomFactor float32 - prevMaxRandomFactor float32 + mu struct { + sync.RWMutex + + config NUDConfigurations + + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. + reachableTime time.Duration + + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 + } } // NewNUDState returns new NUDState using c as configuration and the specified @@ -351,7 +346,7 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { s := &NUDState{ rng: rng, } - s.config = c + s.mu.config = c return s } @@ -359,14 +354,14 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { func (s *NUDState) Config() NUDConfigurations { s.mu.RLock() defer s.mu.RUnlock() - return s.config + return s.mu.config } // SetConfig replaces the existing NUD configurations with c. func (s *NUDState) SetConfig(c NUDConfigurations) { s.mu.Lock() defer s.mu.Unlock() - s.config = c + s.mu.config = c } // ReachableTime returns the duration to wait for a REACHABLE entry to @@ -377,13 +372,13 @@ func (s *NUDState) ReachableTime() time.Duration { s.mu.Lock() defer s.mu.Unlock() - if time.Now().After(s.expiration) || - s.config.BaseReachableTime != s.prevBaseReachableTime || - s.config.MinRandomFactor != s.prevMinRandomFactor || - s.config.MaxRandomFactor != s.prevMaxRandomFactor { + if time.Now().After(s.mu.expiration) || + s.mu.config.BaseReachableTime != s.mu.prevBaseReachableTime || + s.mu.config.MinRandomFactor != s.mu.prevMinRandomFactor || + s.mu.config.MaxRandomFactor != s.mu.prevMaxRandomFactor { s.recomputeReachableTimeLocked() } - return s.reachableTime + return s.mu.reachableTime } // recomputeReachableTimeLocked forces a recalculation of ReachableTime using @@ -408,23 +403,23 @@ func (s *NUDState) ReachableTime() time.Duration { // // s.mu MUST be locked for writing. func (s *NUDState) recomputeReachableTimeLocked() { - s.prevBaseReachableTime = s.config.BaseReachableTime - s.prevMinRandomFactor = s.config.MinRandomFactor - s.prevMaxRandomFactor = s.config.MaxRandomFactor + s.mu.prevBaseReachableTime = s.mu.config.BaseReachableTime + s.mu.prevMinRandomFactor = s.mu.config.MinRandomFactor + s.mu.prevMaxRandomFactor = s.mu.config.MaxRandomFactor - randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + randomFactor := s.mu.config.MinRandomFactor + s.rng.Float32()*(s.mu.config.MaxRandomFactor-s.mu.config.MinRandomFactor) // Check for overflow, given that minRandomFactor and maxRandomFactor are // guaranteed to be positive numbers. - if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { - s.reachableTime = time.Duration(math.MaxInt64) + if math.MaxInt64/randomFactor < float32(s.mu.config.BaseReachableTime) { + s.mu.reachableTime = time.Duration(math.MaxInt64) } else if randomFactor == 1 { // Avoid loss of precision when a large base reachable time is used. - s.reachableTime = s.config.BaseReachableTime + s.mu.reachableTime = s.mu.config.BaseReachableTime } else { - reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) - s.reachableTime = time.Duration(reachableTime) + reachableTime := int64(float32(s.mu.config.BaseReachableTime) * randomFactor) + s.mu.reachableTime = time.Duration(reachableTime) } - s.expiration = time.Now().Add(2 * time.Hour) + s.mu.expiration = time.Now().Add(2 * time.Hour) } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index e1253f310..6ba97d626 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -28,17 +28,15 @@ import ( ) const ( - defaultBaseReachableTime = 30 * time.Second - minimumBaseReachableTime = time.Millisecond - defaultMinRandomFactor = 0.5 - defaultMaxRandomFactor = 1.5 - defaultRetransmitTimer = time.Second - minimumRetransmitTimer = time.Millisecond - defaultDelayFirstProbeTime = 5 * time.Second - defaultMaxMulticastProbes = 3 - defaultMaxUnicastProbes = 3 - defaultMaxAnycastDelayTime = time.Second - defaultMaxReachbilityConfirmations = 3 + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 defaultFakeRandomNum = 0.5 ) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 646979d1e..4ca702121 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -16,9 +16,10 @@ package stack import ( "fmt" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" + tcpipbuffer "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -39,7 +40,11 @@ type PacketBufferOptions struct { // Data is the initial unparsed data for the new packet. If set, it will be // owned by the new packet. - Data buffer.VectorisedView + Data tcpipbuffer.VectorisedView + + // IsForwardedPacket identifies that the PacketBuffer being created is for a + // forwarded packet. + IsForwardedPacket bool } // A PacketBuffer contains all the data of a network packet. @@ -52,6 +57,34 @@ type PacketBufferOptions struct { // empty. Use of PacketBuffer in any other order is unsupported. // // PacketBuffer must be created with NewPacketBuffer. +// +// Internal structure: A PacketBuffer holds a pointer to buffer.Buffer, which +// exposes a logically-contiguous byte storage. The underlying storage structure +// is abstracted out, and should not be a concern here for most of the time. +// +// |- reserved ->| +// |--->| consumed (incoming) +// 0 V V +// +--------+----+----+--------------------+ +// | | | | current data ... | (buf) +// +--------+----+----+--------------------+ +// ^ | +// |<---| pushed (outgoing) +// +// When a PacketBuffer is created, a `reserved` header region can be specified, +// which stack pushes headers in this region for an outgoing packet. There could +// be no such region for an incoming packet, and `reserved` is 0. The value of +// `reserved` never changes in the entire lifetime of the packet. +// +// Outgoing Packet: When a header is pushed, `pushed` gets incremented by the +// pushed length, and the current value is stored for each header. PacketBuffer +// substracts this value from `reserved` to compute the starting offset of each +// header in `buf`. +// +// Incoming Packet: When a header is consumed (a.k.a. parsed), the current +// `consumed` value is stored for each header, and it gets incremented by the +// consumed length. PacketBuffer adds this value to `reserved` to compute the +// starting offset of each header in `buf`. type PacketBuffer struct { _ sync.NoCopy @@ -59,28 +92,16 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // data holds the payload of the packet. - // - // For inbound packets, Data is initially the whole packet. Then gets moved to - // headers via PacketHeader.Consume, when the packet is being parsed. - // - // For outbound packets, Data is the innermost layer, defined by the protocol. - // Headers are pushed in front of it via PacketHeader.Push. - // - // The bytes backing Data are immutable, a.k.a. users shouldn't write to its - // backing storage. - data buffer.VectorisedView + // buf is the underlying buffer for the packet. See struct level docs for + // details. + buf *buffer.Buffer + reserved int + pushed int + consumed int // headers stores metadata about each header. headers [numHeaderType]headerInfo - // header is the internal storage for outbound packets. Headers will be pushed - // (prepended) on this storage as the packet is being constructed. - // - // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and - // data are held in the same underlying buffer storage. - header buffer.Prependable - // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty() // returns false. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol @@ -127,10 +148,17 @@ type PacketBuffer struct { // NewPacketBuffer creates a new PacketBuffer with opts. func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { pk := &PacketBuffer{ - data: opts.Data, + buf: &buffer.Buffer{}, } if opts.ReserveHeaderBytes != 0 { - pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) + pk.buf.AppendOwned(make([]byte, opts.ReserveHeaderBytes)) + pk.reserved = opts.ReserveHeaderBytes + } + for _, v := range opts.Data.Views() { + pk.buf.AppendOwned(v) + } + if opts.IsForwardedPacket { + pk.NetworkPacketInfo.IsForwardedPacket = opts.IsForwardedPacket } return pk } @@ -138,13 +166,13 @@ func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { // ReservedHeaderBytes returns the number of bytes initially reserved for // headers. func (pk *PacketBuffer) ReservedHeaderBytes() int { - return pk.header.UsedLength() + pk.header.AvailableLength() + return pk.reserved } // AvailableHeaderBytes returns the number of bytes currently available for // headers. This is relevant to PacketHeader.Push method only. func (pk *PacketBuffer) AvailableHeaderBytes() int { - return pk.header.AvailableLength() + return pk.reserved - pk.pushed } // LinkHeader returns the handle to link-layer header. @@ -173,24 +201,18 @@ func (pk *PacketBuffer) TransportHeader() PacketHeader { // HeaderSize returns the total size of all headers in bytes. func (pk *PacketBuffer) HeaderSize() int { - // Note for inbound packets (Consume called), headers are not stored in - // pk.header. Thus, calculation of size of each header is needed. - var size int - for i := range pk.headers { - size += len(pk.headers[i].buf) - } - return size + return pk.pushed + pk.consumed } // Size returns the size of packet in bytes. func (pk *PacketBuffer) Size() int { - return pk.HeaderSize() + pk.data.Size() + return int(pk.buf.Size()) - pk.headerOffset() } // MemSize returns the estimation size of the pk in memory, including backing // buffer data. func (pk *PacketBuffer) MemSize() int { - return pk.HeaderSize() + pk.data.MemSize() + packetBufferStructSize + return int(pk.buf.Size()) + packetBufferStructSize } // Data returns the handle to data portion of pk. @@ -199,61 +221,65 @@ func (pk *PacketBuffer) Data() PacketData { } // Views returns the underlying storage of the whole packet. -func (pk *PacketBuffer) Views() []buffer.View { - // Optimization for outbound packets that headers are in pk.header. - useHeader := true - for i := range pk.headers { - if !canUseHeader(&pk.headers[i]) { - useHeader = false - break - } - } +func (pk *PacketBuffer) Views() []tcpipbuffer.View { + var views []tcpipbuffer.View + offset := pk.headerOffset() + pk.buf.SubApply(offset, int(pk.buf.Size())-offset, func(v []byte) { + views = append(views, v) + }) + return views +} - dataViews := pk.data.Views() - - var vs []buffer.View - if useHeader { - vs = make([]buffer.View, 0, 1+len(dataViews)) - vs = append(vs, pk.header.View()) - } else { - vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews)) - for i := range pk.headers { - if v := pk.headers[i].buf; len(v) > 0 { - vs = append(vs, v) - } - } - } - return append(vs, dataViews...) +func (pk *PacketBuffer) headerOffset() int { + return pk.reserved - pk.pushed +} + +func (pk *PacketBuffer) headerOffsetOf(typ headerType) int { + return pk.reserved + pk.headers[typ].offset } -func canUseHeader(h *headerInfo) bool { - // h.offset will be negative if the header was pushed in to prependable - // portion, or doesn't matter when it's empty. - return len(h.buf) == 0 || h.offset < 0 +func (pk *PacketBuffer) dataOffset() int { + return pk.reserved + pk.consumed } -func (pk *PacketBuffer) push(typ headerType, size int) buffer.View { +func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View { h := &pk.headers[typ] - if h.buf != nil { - panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + if h.length > 0 { + panic(fmt.Sprintf("push(%s, %d) called after previous push", typ, size)) + } + if pk.pushed+size > pk.reserved { + panic(fmt.Sprintf("push(%s, %d) overflows; pushed=%d reserved=%d", typ, size, pk.pushed, pk.reserved)) } - h.buf = buffer.View(pk.header.Prepend(size)) - h.offset = -pk.header.UsedLength() - return h.buf + pk.pushed += size + h.offset = -pk.pushed + h.length = size + return pk.headerView(typ) } -func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) { +func (pk *PacketBuffer) consume(typ headerType, size int) (v tcpipbuffer.View, consumed bool) { h := &pk.headers[typ] - if h.buf != nil { + if h.length > 0 { panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) } - v, ok := pk.data.PullUp(size) + if pk.reserved+pk.consumed+size > int(pk.buf.Size()) { + return nil, false + } + h.offset = pk.consumed + h.length = size + pk.consumed += size + return pk.headerView(typ), true +} + +func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View { + h := &pk.headers[typ] + if h.length == 0 { + return nil + } + v, ok := pk.buf.PullUp(pk.headerOffsetOf(typ), h.length) if !ok { - return + panic("PullUp failed") } - pk.data.TrimFront(size) - h.buf = v - return h.buf, true + return v } // Clone makes a shallow copy of pk. @@ -263,9 +289,11 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum func (pk *PacketBuffer) Clone() *PacketBuffer { return &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, - data: pk.data.Clone(nil), + buf: pk.buf, + reserved: pk.reserved, + pushed: pk.pushed, + consumed: pk.consumed, headers: pk.headers, - header: pk.header, Hash: pk.Hash, Owner: pk.Owner, GSOOptions: pk.GSOOptions, @@ -299,9 +327,11 @@ func (pk *PacketBuffer) Network() header.Network { // See PacketBuffer.Data for details about how a packet buffer holds an inbound // packet. func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { - newPk := NewPacketBuffer(PacketBufferOptions{ - Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), - }) + newPk := &PacketBuffer{ + buf: pk.buf, + // Treat unfilled header portion as reserved. + reserved: pk.AvailableHeaderBytes(), + } // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to // maintain this flag in the packet. Currently conntrack needs this flag to // tell if a noop connection should be inserted at Input hook. Once conntrack @@ -315,15 +345,12 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { // headerInfo stores metadata about a header in a packet. type headerInfo struct { - // buf is the memorized slice for both prepended and consumed header. - // When header is prepended, buf serves as memorized value, which is a slice - // of pk.header. When header is consumed, buf is the slice pulled out from - // pk.Data, which is the only place to hold this header. - buf buffer.View - - // offset will be a negative number denoting the offset where this header is - // from the end of pk.header, if it is prepended. Otherwise, zero. + // offset is the offset of the header in pk.buf relative to + // pk.buf[pk.reserved]. See the PacketBuffer struct for details. offset int + + // length is the length of this header. + length int } // PacketHeader is a handle object to a header in the underlying packet. @@ -333,14 +360,14 @@ type PacketHeader struct { } // View returns the underlying storage of h. -func (h PacketHeader) View() buffer.View { - return h.pk.headers[h.typ].buf +func (h PacketHeader) View() tcpipbuffer.View { + return h.pk.headerView(h.typ) } // Push pushes size bytes in the front of its residing packet, and returns the // backing storage. Callers may only call one of Push or Consume once on each // header in the lifetime of the underlying packet. -func (h PacketHeader) Push(size int) buffer.View { +func (h PacketHeader) Push(size int) tcpipbuffer.View { return h.pk.push(h.typ, size) } @@ -349,7 +376,7 @@ func (h PacketHeader) Push(size int) buffer.View { // size, consumed will be false, and the state of h will not be affected. // Callers may only call one of Push or Consume once on each header in the // lifetime of the underlying packet. -func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { +func (h PacketHeader) Consume(size int) (v tcpipbuffer.View, consumed bool) { return h.pk.consume(h.typ, size) } @@ -360,54 +387,84 @@ type PacketData struct { // PullUp returns a contiguous view of size bytes from the beginning of d. // Callers should not write to or keep the view for later use. -func (d PacketData) PullUp(size int) (buffer.View, bool) { - return d.pk.data.PullUp(size) +func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) { + return d.pk.buf.PullUp(d.pk.dataOffset(), size) } -// TrimFront removes count from the beginning of d. It panics if count > -// d.Size(). -func (d PacketData) TrimFront(count int) { - d.pk.data.TrimFront(count) +// DeleteFront removes count from the beginning of d. It panics if count > +// d.Size(). All backing storage references after the front of the d are +// invalidated. +func (d PacketData) DeleteFront(count int) { + if !d.pk.buf.Remove(d.pk.dataOffset(), count) { + panic("count > d.Size()") + } } // CapLength reduces d to at most length bytes. func (d PacketData) CapLength(length int) { - d.pk.data.CapLength(length) + if length < 0 { + panic("length < 0") + } + if currLength := d.Size(); currLength > length { + trim := currLength - length + d.pk.buf.Remove(int(d.pk.buf.Size())-trim, trim) + } } // Views returns the underlying storage of d in a slice of Views. Caller should // not modify the returned slice. -func (d PacketData) Views() []buffer.View { - return d.pk.data.Views() +func (d PacketData) Views() []tcpipbuffer.View { + var views []tcpipbuffer.View + offset := d.pk.dataOffset() + d.pk.buf.SubApply(offset, int(d.pk.buf.Size())-offset, func(v []byte) { + views = append(views, v) + }) + return views } // AppendView appends v into d, taking the ownership of v. -func (d PacketData) AppendView(v buffer.View) { - d.pk.data.AppendView(v) +func (d PacketData) AppendView(v tcpipbuffer.View) { + d.pk.buf.AppendOwned(v) } -// ReadFromData moves at most count bytes from the beginning of srcData to the -// end of d and returns the number of bytes moved. -func (d PacketData) ReadFromData(srcData PacketData, count int) int { - return srcData.pk.data.ReadToVV(&d.pk.data, count) +// MergeFragment appends the data portion of frag to dst. It takes ownership of +// frag and frag should not be used again. +func MergeFragment(dst, frag *PacketBuffer) { + frag.buf.TrimFront(int64(frag.dataOffset())) + dst.buf.Merge(frag.buf) } // ReadFromVV moves at most count bytes from the beginning of srcVV to the end // of d and returns the number of bytes moved. -func (d PacketData) ReadFromVV(srcVV *buffer.VectorisedView, count int) int { - return srcVV.ReadToVV(&d.pk.data, count) +func (d PacketData) ReadFromVV(srcVV *tcpipbuffer.VectorisedView, count int) int { + done := 0 + for _, v := range srcVV.Views() { + if len(v) < count { + count -= len(v) + done += len(v) + d.pk.buf.AppendOwned(v) + } else { + v = v[:count] + count -= len(v) + done += len(v) + d.pk.buf.Append(v) + break + } + } + srcVV.TrimFront(done) + return done } // Size returns the number of bytes in the data payload of the packet. func (d PacketData) Size() int { - return d.pk.data.Size() + return int(d.pk.buf.Size()) - d.pk.dataOffset() } // AsRange returns a Range representing the current data payload of the packet. func (d PacketData) AsRange() Range { return Range{ pk: d.pk, - offset: d.pk.HeaderSize(), + offset: d.pk.dataOffset(), length: d.Size(), } } @@ -417,17 +474,12 @@ func (d PacketData) AsRange() Range { // // This method exists for compatibility between PacketBuffer and VectorisedView. // It may be removed later and should be used with care. -func (d PacketData) ExtractVV() buffer.VectorisedView { - return d.pk.data -} - -// Replace replaces the data portion of the packet with vv, taking the ownership -// of vv. -// -// This method exists for compatibility between PacketBuffer and VectorisedView. -// It may be removed later and should be used with care. -func (d PacketData) Replace(vv buffer.VectorisedView) { - d.pk.data = vv +func (d PacketData) ExtractVV() tcpipbuffer.VectorisedView { + var vv tcpipbuffer.VectorisedView + d.pk.buf.SubApply(d.pk.dataOffset(), d.pk.Size(), func(v []byte) { + vv.AppendView(v) + }) + return vv } // Range represents a contiguous subportion of a PacketBuffer. @@ -471,9 +523,9 @@ func (r Range) Capped(max int) Range { // AsView returns the backing storage of r if possible. It will allocate a new // View if r spans multiple pieces internally. Caller should not write to the // returned View in any way. -func (r Range) AsView() buffer.View { +func (r Range) AsView() tcpipbuffer.View { var allocated bool - var v buffer.View + var v tcpipbuffer.View r.iterate(func(b []byte) { if v == nil { // v has not been assigned, allowing first view to be returned. @@ -494,7 +546,7 @@ func (r Range) AsView() buffer.View { } // ToOwnedView returns a owned copy of data in r. -func (r Range) ToOwnedView() buffer.View { +func (r Range) ToOwnedView() tcpipbuffer.View { if r.length == 0 { return nil } @@ -515,63 +567,7 @@ func (r Range) Checksum() uint16 { // iterate calls fn for each piece in r. fn is always called with a non-empty // slice. func (r Range) iterate(fn func([]byte)) { - w := window{ - offset: r.offset, - length: r.length, - } - // Header portion. - for i := range r.pk.headers { - if b := w.process(r.pk.headers[i].buf); len(b) > 0 { - fn(b) - } - if w.isDone() { - break - } - } - // Data portion. - if !w.isDone() { - for _, v := range r.pk.data.Views() { - if b := w.process(v); len(b) > 0 { - fn(b) - } - if w.isDone() { - break - } - } - } -} - -// window represents contiguous region of byte stream. User would call process() -// to input bytes, and obtain a subslice that is inside the window. -type window struct { - offset int - length int -} - -// isDone returns true if the window has passed and further process() calls will -// always return an empty slice. This can be used to end processing early. -func (w *window) isDone() bool { - return w.length == 0 -} - -// process feeds b in and returns a subslice that is inside the window. The -// returned slice will be a subslice of b, and it does not keep b after method -// returns. This method may return an empty slice if nothing in b is inside the -// window. -func (w *window) process(b []byte) (inWindow []byte) { - if w.offset >= len(b) { - w.offset -= len(b) - return nil - } - if w.offset > 0 { - b = b[w.offset:] - w.offset = 0 - } - if w.length < len(b) { - b = b[:w.length] - } - w.length -= len(b) - return b + r.pk.buf.SubApply(r.offset, r.length, fn) } // PayloadSince returns packet payload starting from and including a particular @@ -579,21 +575,14 @@ func (w *window) process(b []byte) (inWindow []byte) { // // The returned View is owned by the caller - its backing buffer is separate // from the packet header's underlying packet buffer. -func PayloadSince(h PacketHeader) buffer.View { - size := h.pk.data.Size() - for _, hinfo := range h.pk.headers[h.typ:] { - size += len(hinfo.buf) +func PayloadSince(h PacketHeader) tcpipbuffer.View { + offset := h.pk.headerOffset() + for i := headerType(0); i < h.typ; i++ { + offset += h.pk.headers[i].length } - - v := make(buffer.View, 0, size) - - for _, hinfo := range h.pk.headers[h.typ:] { - v = append(v, hinfo.buf...) - } - - for _, view := range h.pk.data.Views() { - v = append(v, view...) - } - - return v + return Range{ + pk: h.pk, + offset: offset, + length: int(h.pk.buf.Size()) - offset, + }.ToOwnedView() } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index 6728370c3..a8da34992 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -112,23 +112,13 @@ func TestPacketHeaderPush(t *testing.T) { if got, want := pk.Size(), allHdrSize+len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - checkData(t, pk, test.data) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), - concatViews(test.link, test.network, test.transport, test.data)) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(test.link, test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(test.transport, test.data)) + // Check the after state. + checkPacketContents(t, "After ", pk, packetContents{ + link: test.link, + network: test.network, + transport: test.transport, + data: test.data, + }) }) } } @@ -199,29 +189,13 @@ func TestPacketHeaderConsume(t *testing.T) { if got, want := pk.Size(), len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - // After state of pk. - var ( - link = test.data[:test.link] - network = test.data[test.link:][:test.network] - transport = test.data[test.link+test.network:][:test.transport] - payload = test.data[allHdrSize:] - ) - checkData(t, pk, payload) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(link, network, transport, payload)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(network, transport, payload)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(transport, payload)) + // Check the after state of pk. + checkPacketContents(t, "After ", pk, packetContents{ + link: test.data[:test.link], + network: test.data[test.link:][:test.network], + transport: test.data[test.link+test.network:][:test.transport], + data: test.data[allHdrSize:], + }) }) } } @@ -252,6 +226,70 @@ func TestPacketHeaderConsumeDataTooShort(t *testing.T) { }) } +// This is a very obscure use-case seen in the code that verifies packets +// before sending them out. It tries to parse the headers to verify. +// PacketHeader was initially not designed to mix Push() and Consume(), but it +// works and it's been relied upon. Include a test here. +func TestPacketHeaderPushConsumeMixed(t *testing.T) { + link := makeView(10) + network := makeView(20) + data := makeView(30) + + initData := append([]byte(nil), network...) + initData = append(initData, data...) + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: len(link), + Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), + }) + + // 1. Consume network header + gotNetwork, ok := pk.NetworkHeader().Consume(len(network)) + if !ok { + t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network)) + } + checkViewEqual(t, "gotNetwork", gotNetwork, network) + + // 2. Push link header + copy(pk.LinkHeader().Push(len(link)), link) + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + network: network, + data: data, + }) +} + +func TestPacketHeaderPushConsumeMixedTooLong(t *testing.T) { + link := makeView(10) + network := makeView(20) + data := makeView(30) + + initData := concatViews(network, data) + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: len(link), + Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), + }) + + // 1. Push link header + copy(pk.LinkHeader().Push(len(link)), link) + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + data: initData, + }) + + // 2. Consume network header, with a number of bytes too large. + gotNetwork, ok := pk.NetworkHeader().Consume(len(initData) + 1) + if ok { + t.Fatalf("pk.NetworkHeader().Consume(%d) = %q, true; want _, false", len(initData)+1, gotNetwork) + } + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + data: initData, + }) +} + func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { const headerSize = 10 @@ -397,11 +435,11 @@ func TestPacketBufferData(t *testing.T) { } }) - // TrimFront + // DeleteFront for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) { + t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { pkt := tc.makePkt(t) - pkt.Data().TrimFront(n) + pkt.Data().DeleteFront(n) checkData(t, pkt, []byte(tc.data)[n:]) }) @@ -437,23 +475,8 @@ func TestPacketBufferData(t *testing.T) { checkData(t, pkt, []byte(tc.data+s)) }) - // ReadFromData/VV + // ReadFromVV for _, n := range []int{0, 1, 2, 7, 10, 14, 20} { - t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) { - s := "TO READ" - otherPkt := NewPacketBuffer(PacketBufferOptions{ - Data: vv(s, s), - }) - s += s - - pkt := tc.makePkt(t) - pkt.Data().ReadFromData(otherPkt.Data(), n) - - if n < len(s) { - s = s[:n] - } - checkData(t, pkt, []byte(tc.data+s)) - }) t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) { s := "TO READ" srcVV := vv(s, s) @@ -480,20 +503,41 @@ func TestPacketBufferData(t *testing.T) { t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want) } }) - - // Replace - t.Run("Replace", func(t *testing.T) { - s := "REPLACED" - - pkt := tc.makePkt(t) - pkt.Data().Replace(vv(s)) - - checkData(t, pkt, []byte(s)) - }) }) } } +type packetContents struct { + link buffer.View + network buffer.View + transport buffer.View + data buffer.View +} + +func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) { + t.Helper() + // Headers. + checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link) + checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network) + checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport) + // Data. + checkData(t, pk, want.data) + // Whole packet. + checkViewEqual(t, prefix+"pk.Views()", + concatViews(pk.Views()...), + concatViews(want.link, want.network, want.transport, want.data)) + // PayloadSince. + checkViewEqual(t, prefix+"PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(want.link, want.network, want.transport, want.data)) + checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(want.network, want.transport, want.data)) + checkViewEqual(t, prefix+"PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(want.transport, want.data)) +} + func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { t.Helper() reserved := opts.ReserveHeaderBytes @@ -510,19 +554,9 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO if got, want := pk.Size(), len(data); got != want { t.Errorf("Initial pk.Size() = %d, want %d", got, want) } - checkData(t, pk, data) - checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) - // Check the initial values for each header. - checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) - checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) - checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) - // Check the initial valies for PayloadSince. - checkViewEqual(t, "Initial PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), data) + checkPacketContents(t, "Initial ", pk, packetContents{ + data: data, + }) } func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { @@ -540,7 +574,7 @@ func checkViewEqual(t *testing.T, what string, got, want buffer.View) { func checkData(t *testing.T, pkt *PacketBuffer, want []byte) { t.Helper() if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) { - t.Errorf("pkt.Data().Views() = %x, want %x", got, want) + t.Errorf("pkt.Data().Views() = 0x%x, want 0x%x", got, want) } if got := pkt.Data().Size(); got != len(want) { t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want)) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 7ad206f6d..85bb87b4b 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -55,6 +55,9 @@ type NetworkPacketInfo struct { // LocalAddressBroadcast is true if the packet's local address is a broadcast // address. LocalAddressBroadcast bool + + // IsForwardedPacket is true if the packet is being forwarded. + IsForwardedPacket bool } // TransportErrorKind enumerates error types that are handled by the transport @@ -655,9 +658,9 @@ type IPNetworkEndpointStats interface { IPStats() *tcpip.IPStats } -// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets. -type ForwardingNetworkProtocol interface { - NetworkProtocol +// ForwardingNetworkEndpoint is a network endpoint that may forward packets. +type ForwardingNetworkEndpoint interface { + NetworkEndpoint // Forwarding returns the forwarding configuration. Forwarding() bool @@ -756,11 +759,6 @@ const ( CapabilitySaveRestore CapabilityDisconnectOk CapabilityLoopback - CapabilityHardwareGSO - - // CapabilitySoftwareGSO indicates the link endpoint supports of sending - // multiple packets using a single call (LinkEndpoint.WritePackets). - CapabilitySoftwareGSO ) // NetworkLinkEndpoint is a data-link layer that supports sending network @@ -1047,10 +1045,29 @@ type GSO struct { MaxSize uint32 } +// SupportedGSO returns the type of segmentation offloading supported. +type SupportedGSO int + +const ( + // GSONotSupported indicates that segmentation offloading is not supported. + GSONotSupported SupportedGSO = iota + + // HWGSOSupported indicates that segmentation offloading may be performed by + // the hardware. + HWGSOSupported + + // SWGSOSupported indicates that segmentation offloading may be performed in + // software. + SWGSOSupported +) + // GSOEndpoint provides access to GSO properties. type GSOEndpoint interface { // GSOMaxSize returns the maximum GSO packet size. GSOMaxSize() uint32 + + // SupportedGSO returns the supported segmentation offloading. + SupportedGSO() SupportedGSO } // SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 4ecde5995..f17c04277 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -300,12 +300,18 @@ func (r *Route) RequiresTXTransportChecksum() bool { // HasSoftwareGSOCapability returns true if the route supports software GSO. func (r *Route) HasSoftwareGSOCapability() bool { - return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0 + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { + return gso.SupportedGSO() == SWGSOSupported + } + return false } // HasHardwareGSOCapability returns true if the route supports hardware GSO. func (r *Route) HasHardwareGSOCapability() bool { - return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0 + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { + return gso.SupportedGSO() == HWGSOSupported + } + return false } // HasSaveRestoreCapability returns true if the route supports save/restore. @@ -440,7 +446,7 @@ func (r *Route) isValidForOutgoingRLocked() bool { // If the source NIC and outgoing NIC are different, make sure the stack has // forwarding enabled, or the packet will be handled locally. - if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) { + if r.outgoingNIC != r.localAddressNIC && !isNICForwarding(r.localAddressNIC, r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) { return false } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 843118b13..72760a4a7 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -29,6 +29,7 @@ import ( "time" "golang.org/x/time/rate" + "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -39,13 +40,6 @@ import ( ) const ( - // ageLimit is set to the same cache stale time used in Linux. - ageLimit = 1 * time.Minute - // resolutionTimeout is set to the same ARP timeout used in Linux. - resolutionTimeout = 1 * time.Second - // resolutionAttempts is set to the same ARP retries used in Linux. - resolutionAttempts = 3 - // DefaultTOS is the default type of service value for network endpoints. DefaultTOS = 0 ) @@ -65,10 +59,10 @@ type ResumableEndpoint interface { } // uniqueIDGenerator is a default unique ID generator. -type uniqueIDGenerator uint64 +type uniqueIDGenerator atomicbitops.AlignedAtomicUint64 func (u *uniqueIDGenerator) UniqueID() uint64 { - return atomic.AddUint64((*uint64)(u), 1) + return ((*atomicbitops.AlignedAtomicUint64)(u)).Add(1) } // Stack is a networking stack, with all supported protocols, NICs, and route @@ -94,8 +88,9 @@ type Stack struct { } } - mu sync.RWMutex - nics map[tcpip.NICID]*nic + mu sync.RWMutex + nics map[tcpip.NICID]*nic + defaultForwardingEnabled map[tcpip.NetworkProtocolNumber]struct{} // cleanupEndpointsMu protects cleanupEndpoints. cleanupEndpointsMu sync.Mutex @@ -322,7 +317,7 @@ func (*TransportEndpointInfo) IsEndpointInfo() {} func New(opts Options) *Stack { clock := opts.Clock if clock == nil { - clock = &tcpip.StdClock{} + clock = tcpip.NewStdClock() } if opts.UniqueID == nil { @@ -347,22 +342,23 @@ func New(opts Options) *Stack { } s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - nics: make(map[tcpip.NICID]*nic), - cleanupEndpoints: make(map[TransportEndpoint]struct{}), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - tables: opts.IPTables, - icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), - nudConfigs: opts.NUDConfigs, - uniqueIDGenerator: opts.UniqueID, - nudDisp: opts.NUDDisp, - randomGenerator: mathrand.New(randSrc), - secureRNG: opts.SecureRNG, + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + nics: make(map[tcpip.NICID]*nic), + defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + tables: opts.IPTables, + icmpRateLimiter: NewICMPRateLimiter(), + seed: generateRandUint32(), + nudConfigs: opts.NUDConfigs, + uniqueIDGenerator: opts.UniqueID, + nudDisp: opts.NUDDisp, + randomGenerator: mathrand.New(randSrc), + secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -491,37 +487,61 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables packet forwarding between NICs for the -// passed protocol. -func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { - protocol, ok := s.networkProtocols[protocolNum] +// SetNICForwarding enables or disables packet forwarding on the specified NIC +// for the passed protocol. +func (s *Stack) SetNICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] if !ok { - return &tcpip.ErrUnknownProtocol{} + return &tcpip.ErrUnknownNICID{} } - forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) + return nic.setForwarding(protocol, enable) +} + +// NICForwarding returns the forwarding configuration for the specified NIC. +func (s *Stack) NICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] if !ok { - return &tcpip.ErrNotSupported{} + return false, &tcpip.ErrUnknownNICID{} } - forwardingProtocol.SetForwarding(enable) - return nil + return nic.forwarding(protocol) } -// Forwarding returns true if packet forwarding between NICs is enabled for the -// passed protocol. -func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { - protocol, ok := s.networkProtocols[protocolNum] - if !ok { - return false +// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the +// passed protocol and sets the default setting for newly created NICs. +func (s *Stack) SetForwardingDefaultAndAllNICs(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + doneOnce := false + for id, nic := range s.nics { + if err := nic.setForwarding(protocol, enable); err != nil { + // Expect forwarding to be settable on all interfaces if it was set on + // one. + if doneOnce { + panic(fmt.Sprintf("nic(id=%d).setForwarding(%d, %t): %s", id, protocol, enable, err)) + } + + return err + } + + doneOnce = true } - forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) - if !ok { - return false + if enable { + s.defaultForwardingEnabled[protocol] = struct{}{} + } else { + delete(s.defaultForwardingEnabled, protocol) } - return forwardingProtocol.Forwarding() + return nil } // PortRange returns the UDP and TCP inclusive range of ephemeral ports used in @@ -658,6 +678,11 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp } n := newNIC(s, id, opts.Name, ep, opts.Context) + for proto := range s.defaultForwardingEnabled { + if err := n.setForwarding(proto, true); err != nil { + panic(fmt.Sprintf("newNIC(%d, ...).setForwarding(%d, true): %s", id, proto, err)) + } + } s.nics[id] = n if !opts.Disabled { return n.enable() @@ -772,7 +797,7 @@ type NICInfo struct { // MTU is the maximum transmission unit. MTU uint32 - Stats NICStats + Stats tcpip.NICStats // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC. NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats @@ -785,6 +810,10 @@ type NICInfo struct { // value sent in haType field of an ARP Request sent by this NIC and the // value expected in the haType field of an ARP response. ARPHardwareType header.ARPHardwareType + + // Forwarding holds the forwarding status for each network endpoint that + // supports forwarding. + Forwarding map[tcpip.NetworkProtocolNumber]bool } // HasNIC returns true if the NICID is defined in the stack. @@ -814,17 +843,33 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { netStats[proto] = netEP.Stats() } - nics[id] = NICInfo{ + info := NICInfo{ Name: nic.name, LinkAddress: nic.LinkEndpoint.LinkAddress(), ProtocolAddresses: nic.primaryAddresses(), Flags: flags, MTU: nic.LinkEndpoint.MTU(), - Stats: nic.stats, + Stats: nic.stats.local, NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), + Forwarding: make(map[tcpip.NetworkProtocolNumber]bool), } + + for proto := range s.networkProtocols { + switch forwarding, err := nic.forwarding(proto); err.(type) { + case nil: + info.Forwarding[proto] = forwarding + case *tcpip.ErrUnknownProtocol: + panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID())) + case *tcpip.ErrNotSupported: + // Not all network protocols support forwarding. + default: + panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err)) + } + } + + nics[id] = info } return nics } @@ -1028,6 +1073,20 @@ func (s *Stack) HandleLocal() bool { return s.handleLocal } +func isNICForwarding(nic *nic, proto tcpip.NetworkProtocolNumber) bool { + switch forwarding, err := nic.forwarding(proto); err.(type) { + case nil: + return forwarding + case *tcpip.ErrUnknownProtocol: + panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID())) + case *tcpip.ErrNotSupported: + // Not all network protocols support forwarding. + return false + default: + panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err)) + } +} + // FindRoute creates a route to the given destination address, leaving through // the given NIC and local address (if provided). // @@ -1080,7 +1139,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n return nil, &tcpip.ErrNetworkUnreachable{} } - canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal + onlyGlobalAddresses := !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal // Find a route to the remote with the route table. var chosenRoute tcpip.Route @@ -1119,7 +1178,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n // requirement to do this from any RFC but simply a choice made to better // follow a strong host model which the netstack follows at the time of // writing. - if canForward && chosenRoute == (tcpip.Route{}) { + if onlyGlobalAddresses && chosenRoute == (tcpip.Route{}) && isNICForwarding(nic, netProto) { chosenRoute = route } } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 8ead3b8df..73e0f0d58 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -84,7 +84,8 @@ type fakeNetworkEndpoint struct { mu struct { sync.RWMutex - enabled bool + enabled bool + forwarding bool } nic stack.NetworkInterface @@ -138,11 +139,13 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data().PullUp(fakeNetHeaderLen) + hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) if !ok { return } - pkt.Data().TrimFront(fakeNetHeaderLen) + // DeleteFront invalidates slices. Make a copy before trimming. + nb := append([]byte(nil), hdr...) + pkt.Data().DeleteFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), @@ -163,10 +166,6 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fakeNetHeaderLen } -func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -194,11 +193,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHe } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -225,11 +224,6 @@ type fakeNetworkProtocol struct { packetCount [10]int sendPacketCount [10]int defaultTTL uint8 - - mu struct { - sync.RWMutex - forwarding bool - } } func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -298,15 +292,15 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (f *fakeNetworkProtocol) Forwarding() bool { +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (f *fakeNetworkEndpoint) Forwarding() bool { f.mu.RLock() defer f.mu.RUnlock() return f.mu.forwarding } -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (f *fakeNetworkProtocol) SetForwarding(v bool) { +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (f *fakeNetworkEndpoint) SetForwarding(v bool) { f.mu.Lock() defer f.mu.Unlock() f.mu.forwarding = v @@ -465,14 +459,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer } } -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) @@ -922,15 +916,15 @@ func TestRouteWithDownNIC(t *testing.T) { if err := test.downFn(s, nicID1); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID1, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) testSend(t, r2, ep2, buf) // Writes with Routes that use NIC2 after being brought down should fail. if err := test.downFn(s, nicID2); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID2, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) if upFn := test.upFn; upFn != nil { // Writes with Routes that use NIC1 after being brought up should @@ -943,7 +937,7 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) } }) } @@ -1068,7 +1062,7 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. err := s.RemoveAddress(1, localAddr) @@ -1120,8 +1114,8 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. { @@ -1142,7 +1136,7 @@ func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.A // No address given, verify that there is no address assigned to the NIC. for _, a := range info.ProtocolAddresses { if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{}) } } return @@ -1222,7 +1216,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 2. Add Address, everything should work. @@ -1250,7 +1244,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 4. Add Address back, everything should work again. @@ -1289,8 +1283,8 @@ func TestEndpointExpiration(t *testing.T) { testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 7. Add Address back, everything should work again. @@ -1326,7 +1320,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } }) } @@ -1576,7 +1570,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{}) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -1617,7 +1611,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } } - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } @@ -1643,12 +1637,12 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any} // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24} nic1Gateway := testutil.MustParse4("192.168.1.1") // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24} nic2Gateway := testutil.MustParse4("10.10.10.1") // Create a new stack with two NICs. @@ -1662,12 +1656,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err := s.CreateNIC(2, ep); err != nil { t.Fatalf("CreateNIC failed: %s", err) } - nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } - nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } @@ -1711,7 +1705,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, }, rt..., ) @@ -2051,7 +2045,7 @@ func TestAddAddress(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } @@ -2115,7 +2109,7 @@ func TestAddAddressWithOptions(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } } @@ -2236,7 +2230,7 @@ func TestCreateNICWithOptions(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { s := stack.New(stack.Options{}) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") for _, call := range test.calls { if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) @@ -2250,46 +2244,87 @@ func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + + nics := []struct { + addr tcpip.Address + txByteCount int + rxByteCount int + }{ + { + addr: "\x01", + txByteCount: 30, + rxByteCount: 10, + }, + { + addr: "\x02", + txByteCount: 50, + rxByteCount: 20, + }, } - // Route all packets for address \x01 to NIC 1. - { - subnet, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) + + var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int + for i, nic := range nics { + nicid := tcpip.NICID(i) + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { + t.Fatal("CreateNIC failed: ", err) + } + if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { + t.Fatal("AddAddress failed:", err) } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - // Send a packet to address 1. - buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } + { + subnet, err := tcpip.NewSubnet(nic.addr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}}) + } + + nicStats := s.NICInfo()[nicid].Stats - if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + // Inbound packet. + rxBuffer := buffer.NewView(nic.rxByteCount) + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: rxBuffer.ToVectorisedView(), + })) + if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) + } + if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want { + t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + } + rxPacketsTotal++ + rxBytesTotal += nic.rxByteCount + + // Outbound packet. + txBuffer := buffer.NewView(nic.txByteCount) + actualTxLength := nic.txByteCount + fakeNetHeaderLen + if err := sendTo(s, nic.addr, txBuffer); err != nil { + t.Fatal("sendTo failed: ", err) + } + want := ep.Drain() + if got := nicStats.Tx.Packets.Value(); got != uint64(want) { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + txPacketsTotal += want + txBytesTotal += actualTxLength } - payload := buffer.NewView(10) - // Write a packet out via the address for NIC 1 - if err := sendTo(s, "\x01", payload); err != nil { - t.Fatal("sendTo failed: ", err) + // Now verify that each NIC stats was correctly aggregated at the stack level. + if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want { + t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want) } - want := uint64(ep1.Drain()) - if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) + if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want { + t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want) } - - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { + if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } @@ -2318,7 +2353,7 @@ func TestNICContextPreservation(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{}) id := tcpip.NICID(1) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) } @@ -3020,7 +3055,7 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -3839,8 +3874,6 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { // TestAddRoute tests Stack.AddRoute func TestAddRoute(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) subnet1, err := tcpip.NewSubnet("\x00", "\x00") @@ -3877,8 +3910,6 @@ func TestAddRoute(t *testing.T) { // TestRemoveRoutes tests Stack.RemoveRoutes func TestRemoveRoutes(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) addressToRemove := tcpip.Address("\x01") @@ -4218,14 +4249,14 @@ func TestFindRouteWithForwarding(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) } - if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) } s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) - if r != nil { + if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { @@ -4273,8 +4304,8 @@ func TestFindRouteWithForwarding(t *testing.T) { // Disabling forwarding when the route is dependent on forwarding being // enabled should make the route invalid. - if err := s.SetForwarding(test.netCfg.proto, false); err != nil { - t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", test.netCfg.proto, err) } { err := send(r, data) |