From 9e4a1e31d4fbf7d4439d503bf318517c92c8e885 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 8 Apr 2021 09:48:17 -0700 Subject: Join all routers group when forwarding is enabled See comments inline code for rationale. Test: ip_test.TestJoinLeaveAllRoutersGroup PiperOrigin-RevId: 367449434 --- pkg/tcpip/header/ipv6.go | 21 +++++++-- pkg/tcpip/network/ip_test.go | 78 +++++++++++++++++++++++++++++++ pkg/tcpip/network/ipv4/ipv4.go | 57 ++++++++++++++++++++-- pkg/tcpip/network/ipv6/ipv6.go | 63 +++++++++++++++++++++---- pkg/tcpip/network/ipv6/mld.go | 2 +- pkg/tcpip/network/ipv6/mld_test.go | 2 +- pkg/tcpip/network/ipv6/ndp.go | 8 ++-- pkg/tcpip/network/multicast_group_test.go | 6 +-- pkg/tcpip/stack/ndp_test.go | 6 +-- 9 files changed, 213 insertions(+), 30 deletions(-) diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 7711abec1..fa6ccff30 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -98,12 +98,27 @@ const ( // The address is ff02::1. IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - // IPv6AllRoutersMulticastAddress is a link-local multicast group that - // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // IPv6AllRoutersInterfaceLocalMulticastAddress is an interface-local + // multicast group that all IPv6 routers MUST join, as per RFC 4291, section + // 2.8. Packets destined to this address will reach the router on an + // interface. + // + // The address is ff01::2. + IPv6AllRoutersInterfaceLocalMulticastAddress tcpip.Address = "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + // IPv6AllRoutersLinkLocalMulticastAddress is a link-local multicast group + // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets // destined to this address will reach all routers on a link. // // The address is ff02::2. - IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + IPv6AllRoutersLinkLocalMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + // IPv6AllRoutersSiteLocalMulticastAddress is a site-local multicast group + // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all routers in a site. + // + // The address is ff05::2. + IPv6AllRoutersSiteLocalMulticastAddress tcpip.Address = "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200, // section 5: diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index a4edc69c7..58fd18af8 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -15,6 +15,7 @@ package ip_test import ( + "fmt" "strings" "testing" @@ -1938,3 +1939,80 @@ func TestICMPInclusionSize(t *testing.T) { }) } } + +func TestJoinLeaveAllRoutersGroup(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + protoFactory stack.NetworkProtocolFactory + allRoutersAddr tcpip.Address + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + protoFactory: ipv4.NewProtocol, + allRoutersAddr: header.IPv4AllRoutersGroup, + }, + { + name: "IPv6 Interface Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress, + }, + { + name: "IPv6 Link Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress, + }, + { + name: "IPv6 Site Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, nicDisabled := range [...]bool{true, false} { + t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + }) + opts := stack.NICOptions{Disabled: nicDisabled} + if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err) + } + + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if got { + t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) + } + + if err := s.SetForwarding(test.netProto, true); err != nil { + t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err) + } + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if !got { + t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) + } + + if err := s.SetForwarding(test.netProto, false); err != nil { + t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err) + } + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if got { + t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1a5661ca4..6dc6fc9bd 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -150,6 +150,38 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { delete(p.mu.eps, nicID) } +// transitionForwarding transitions the endpoint's forwarding status to +// forwarding. +// +// Must only be called when the forwarding status changes. +func (e *endpoint) transitionForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if forwarding { + // There does not seem to be an RFC requirement for a node to join the all + // routers multicast address but + // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml + // specifies the address as a group for all routers on a subnet so we join + // the group here. + if err := e.joinGroupLocked(header.IPv4AllRoutersGroup); err != nil { + // joinGroupLocked only returns an error if the group address is not a + // valid IPv4 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err)) + } + + return + } + + switch err := e.leaveGroupLocked(header.IPv4AllRoutersGroup).(type) { + case nil: + case *tcpip.ErrBadLocalAddress: + // The endpoint may have already left the multicast group. + default: + panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err)) + } +} + // Enable implements stack.NetworkEndpoint. func (e *endpoint) Enable() tcpip.Error { e.mu.Lock() @@ -226,7 +258,7 @@ func (e *endpoint) disableLocked() { } // The endpoint may have already left the multicast group. - switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) { + switch err := e.leaveGroupLocked(header.IPv4AllSystems).(type) { case nil, *tcpip.ErrBadLocalAddress: default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) @@ -1168,12 +1200,27 @@ func (p *protocol) Forwarding() bool { return uint8(atomic.LoadUint32(&p.forwarding)) == 1 } +// setForwarding sets the forwarding status for the protocol. +// +// Returns true if the forwarding status was updated. +func (p *protocol) setForwarding(v bool) bool { + if v { + return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) + } + return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) +} + // SetForwarding implements stack.ForwardingNetworkProtocol. func (p *protocol) SetForwarding(v bool) { - if v { - atomic.StoreUint32(&p.forwarding, 1) - } else { - atomic.StoreUint32(&p.forwarding, 0) + p.mu.Lock() + defer p.mu.Unlock() + + if !p.setForwarding(v) { + return + } + + for _, ep := range p.mu.eps { + ep.transitionForwarding(v) } } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index c6d9d8f0d..2a2ad6482 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -410,22 +410,65 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t // // Must only be called when the forwarding status changes. func (e *endpoint) transitionForwarding(forwarding bool) { + allRoutersGroups := [...]tcpip.Address{ + header.IPv6AllRoutersInterfaceLocalMulticastAddress, + header.IPv6AllRoutersLinkLocalMulticastAddress, + header.IPv6AllRoutersSiteLocalMulticastAddress, + } + e.mu.Lock() defer e.mu.Unlock() - if !e.Enabled() { - return - } - if forwarding { // When transitioning into an IPv6 router, host-only state (NDP discovered // routers, discovered on-link prefixes, and auto-generated addresses) is // cleaned up/invalidated and NDP router solicitations are stopped. e.mu.ndp.stopSolicitingRouters() e.mu.ndp.cleanupState(true /* hostOnly */) - } else { - // When transitioning into an IPv6 host, NDP router solicitations are - // started. + + // As per RFC 4291 section 2.8: + // + // A router is required to recognize all addresses that a host is + // required to recognize, plus the following addresses as identifying + // itself: + // + // o The All-Routers multicast addresses defined in Section 2.7.1. + // + // As per RFC 4291 section 2.7.1, + // + // All Routers Addresses: FF01:0:0:0:0:0:0:2 + // FF02:0:0:0:0:0:0:2 + // FF05:0:0:0:0:0:0:2 + // + // The above multicast addresses identify the group of all IPv6 routers, + // within scope 1 (interface-local), 2 (link-local), or 5 (site-local). + for _, g := range allRoutersGroups { + if err := e.joinGroupLocked(g); err != nil { + // joinGroupLocked only returns an error if the group address is not a + // valid IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", g, err)) + } + } + + return + } + + for _, g := range allRoutersGroups { + switch err := e.leaveGroupLocked(g).(type) { + case nil: + case *tcpip.ErrBadLocalAddress: + // The endpoint may have already left the multicast group. + default: + panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err)) + } + } + + // When transitioning into an IPv6 host, NDP router solicitations are + // started if the endpoint is enabled. + // + // If the endpoint is not currently enabled, routers will be solicited when + // the endpoint becomes enabled (if it is still a host). + if e.Enabled() { e.mu.ndp.startSolicitingRouters() } } @@ -573,7 +616,7 @@ func (e *endpoint) disableLocked() { e.mu.ndp.cleanupState(false /* hostOnly */) // The endpoint may have already left the multicast group. - switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) { + switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress).(type) { case nil, *tcpip.ErrBadLocalAddress: default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) @@ -1979,9 +2022,9 @@ func (p *protocol) Forwarding() bool { // Returns true if the forwarding status was updated. func (p *protocol) setForwarding(v bool) bool { if v { - return atomic.SwapUint32(&p.forwarding, 1) == 0 + return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) } - return atomic.SwapUint32(&p.forwarding, 0) == 1 + return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) } // SetForwarding implements stack.ForwardingNetworkProtocol. diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index 538590baf..165b7d2d2 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -76,7 +76,7 @@ func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) // // Precondition: mld.ep.mu must be read locked. func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error { - _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + _, err := mld.writePacket(header.IPv6AllRoutersLinkLocalMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) return err } diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 1e611aca1..146b300f1 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -93,7 +93,7 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { if p, ok := e.Read(); !ok { t.Fatal("expected a done message to be sent") } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersLinkLocalMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) } } diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 536493f87..dd7f6a126 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1703,7 +1703,7 @@ func (ndp *ndpState) startSolicitingRouters() { // the unspecified address if no address is assigned // to the sending interface. localAddr := header.IPv6Any - if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil { + if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersLinkLocalMulticastAddress, false); addressEndpoint != nil { localAddr = addressEndpoint.AddressWithPrefix().Address addressEndpoint.DecRef() } @@ -1730,7 +1730,7 @@ func (ndp *ndpState) startSolicitingRouters() { icmpData.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpData, Src: localAddr, - Dst: header.IPv6AllRoutersMulticastAddress, + Dst: header.IPv6AllRoutersLinkLocalMulticastAddress, })) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -1739,14 +1739,14 @@ func (ndp *ndpState) startSolicitingRouters() { }) sent := ndp.ep.stats.icmp.packetsSent - if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + if err := addIPHeader(localAddr, header.IPv6AllRoutersLinkLocalMulticastAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, nil /* extensionHeaders */); err != nil { panic(fmt.Sprintf("failed to add IP header: %s", err)) } - if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.dropped.Increment() // Don't send any more messages if we had an error. remaining = 0 diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index ecd5003a7..2aa4e6d75 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -194,7 +194,7 @@ func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, c if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") } else { - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6AddrSNMC) } // Should not send any more packets. @@ -606,7 +606,7 @@ func TestMGPLeaveGroup(t *testing.T) { validateLeave: func(t *testing.T, p channel.PacketInfo) { t.Helper() - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1) }, checkInitialGroups: checkInitialIPv6Groups, }, @@ -1014,7 +1014,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { t.Helper() - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, addr) }, getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { t.Helper() diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 14124ae66..a869cce38 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -5204,13 +5204,13 @@ func TestRouterSolicitation(t *testing.T) { } // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), ) @@ -5362,7 +5362,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS()) } -- cgit v1.2.3