diff options
Diffstat (limited to 'pkg/tcpip/stack/ndp_test.go')
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 249 |
1 files changed, 159 insertions, 90 deletions
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 011571106..c585b81b2 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1147,7 +1147,10 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on } func TestDynamicConfigurationsDisabled(t *testing.T) { - const nicID = 1 + const ( + nicID = 1 + maxRtrSolicitDelay = time.Second + ) prefix := tcpip.AddressWithPrefix{ Address: testutil.MustParse6("102:304:506:708::"), @@ -1206,7 +1209,12 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { } ndpConfigs := test.config(enable) ndpConfigs.HandleRAs = handle + ndpConfigs.MaxRtrSolicitations = 1 + ndpConfigs.RtrSolicitationInterval = maxRtrSolicitDelay + ndpConfigs.MaxRtrSolicitationDelay = maxRtrSolicitDelay + clock := faketime.NewManualClock() s := stack.New(stack.Options{ + Clock: clock, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, @@ -1216,15 +1224,12 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } - e := channel.New(0, 1280, linkAddr1) + e := channel.New(1, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone()) - - // Make sure that the unhandled RA stat is only incremented when RAs - // are configured to be unhandled. + handleRAsDisabled := handle == ipv6.HandlingRAsDisabled || forwarding ep, err := s.GetNetworkEndpoint(nicID, ipv6.ProtocolNumber) if err != nil { t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ipv6.ProtocolNumber, err) @@ -1234,16 +1239,41 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { if !ok { t.Fatalf("got v6Stats = %T, expected = %T", stats, v6Stats) } - want := uint64(0) - if handle == ipv6.HandlingRAsDisabled || forwarding { - want = 1 + + // Make sure that when handling RAs are enabled, we solicit routers. + clock.Advance(maxRtrSolicitDelay) + if got, want := v6Stats.ICMP.PacketsSent.RouterSolicit.Value(), boolToUint64(!handleRAsDisabled); got != want { + t.Errorf("got v6Stats.ICMP.PacketsSent.RouterSolicit.Value() = %d, want = %d", got, want) } - if got := v6Stats.UnhandledRouterAdvertisements.Value(); got != want { - t.Fatalf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) + if handleRAsDisabled { + if p, ok := e.Read(); ok { + t.Errorf("unexpectedly got a packet = %#v", p) + } + } else if p, ok := e.Read(); !ok { + t.Error("expected router solicitation packet") + } else if p.Proto != header.IPv6ProtocolNumber { + t.Errorf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } else { + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(nil)), + ) } - // Make sure we did not discover any routers or prefixes, or perform - // SLAAC. + // Make sure we do not discover any routers or prefixes, or perform + // SLAAC on reception of an RA. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone()) + // Make sure that the unhandled RA stat is only incremented when + // handling RAs is disabled. + if got, want := v6Stats.UnhandledRouterAdvertisements.Value(), boolToUint64(handleRAsDisabled); got != want { + t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) + } select { case e := <-ndpDisp.routerC: t.Errorf("unexpectedly discovered a router when configured not to: %#v", e) @@ -1265,6 +1295,13 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { } } +func boolToUint64(v bool) uint64 { + if v { + return 1 + } + return 0 +} + // Check e to make sure that the event is for addr on nic with ID 1, and the // discovered flag set to discovered. func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { @@ -5299,96 +5336,127 @@ func TestRouterSolicitation(t *testing.T) { }, } + subTests := []struct { + name string + handleRAs ipv6.HandleRAsConfiguration + afterFirstRS func(*testing.T, *stack.Stack) + }{ + { + name: "Handle RAs when forwarding disabled", + handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + afterFirstRS: func(*testing.T, *stack.Stack) {}, + }, + + // Enabling forwarding when RAs are always configured to be handled + // should not stop router solicitations. + { + name: "Handle RAs always", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + afterFirstRS: func(t *testing.T, s *stack.Stack) { + if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + } + }, + }, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + clock := faketime.NewManualClock() + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + waitForPkt := func(timeout time.Duration) { + t.Helper() + + clock.Advance(timeout) + p, ok := e.Read() + if !ok { + t.Fatal("expected router solicitation packet") + } - clock.Advance(timeout) - p, ok := e.Read() - if !ok { - t.Fatal("expected router solicitation packet") - } + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } + // Make sure the right remote link address is used. + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } - // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.expectedSrcAddr), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), + ) - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) + } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() + clock.Advance(timeout) + if p, ok := e.Read(); ok { + t.Fatalf("unexpectedly got a packet = %#v", p) + } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: subTest.handleRAs, + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - clock.Advance(timeout) - if p, ok := e.Read(); ok { - t.Fatalf("unexpectedly got a packet = %#v", p) - } - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + } + } - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } - } + // Make sure each RS is sent at the right time. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay) + remaining-- + } - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay) - remaining-- - } + subTest.afterFirstRS(t, s) - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) - waitForPkt(time.Nanosecond) - } else { - waitForPkt(test.effectiveRtrSolicitInt) - } - } + for ; remaining > 0; remaining-- { + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) + waitForPkt(time.Nanosecond) + } else { + waitForPkt(test.effectiveRtrSolicitInt) + } + } - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt) - } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay) - } + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay) + } - if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) } }) } @@ -5486,6 +5554,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, MaxRtrSolicitations: maxRtrSolicitations, RtrSolicitationInterval: interval, MaxRtrSolicitationDelay: delay, |