From f50d9a31e9e734a02e0191f6bc91b387bb21f9ab Mon Sep 17 00:00:00 2001
From: Ghanan Gowripalan <ghanan@google.com>
Date: Fri, 6 Mar 2020 11:32:13 -0800
Subject: Specify the source of outgoing NDP RS

If the NIC has a valid IPv6 address assigned, use it as the
source address for outgoing NDP Router Solicitation packets.

Test: stack_test.TestRouterSolicitation
PiperOrigin-RevId: 299398763
---
 pkg/tcpip/stack/ndp.go      | 29 +++++++++++++++++++++++++----
 pkg/tcpip/stack/ndp_test.go | 39 ++++++++++++++++++++++++++++++++-------
 2 files changed, 57 insertions(+), 11 deletions(-)

(limited to 'pkg/tcpip/stack')

diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index f651871ce..a9f4d5dad 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -1220,9 +1220,15 @@ func (ndp *ndpState) startSolicitingRouters() {
 	}
 
 	ndp.rtrSolicitTimer = time.AfterFunc(delay, func() {
-		// Send an RS message with the unspecified source address.
-		ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
-		r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+		// As per RFC 4861 section 4.1, the source of the RS is an address assigned
+		// to the sending interface, or the unspecified address if no address is
+		// assigned to the sending interface.
+		ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress)
+		if ref == nil {
+			ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
+		}
+		localAddr := ref.ep.ID().LocalAddress
+		r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
 		defer r.Release()
 
 		// Route should resolve immediately since
@@ -1234,10 +1240,25 @@ func (ndp *ndpState) startSolicitingRouters() {
 			log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())
 		}
 
-		payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize
+		// As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
+		// link-layer address option if the source address of the NDP RS is
+		// specified. This option MUST NOT be included if the source address is
+		// unspecified.
+		//
+		// TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+		// LinkEndpoint.LinkAddress) before reaching this point.
+		var optsSerializer header.NDPOptionsSerializer
+		if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) {
+			optsSerializer = header.NDPOptionsSerializer{
+				header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress),
+			}
+		}
+		payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
 		hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize)
 		pkt := header.ICMPv6(hdr.Prepend(payloadSize))
 		pkt.SetType(header.ICMPv6RouterSolicit)
+		rs := header.NDPRouterSolicit(pkt.NDPPayload())
+		rs.Options().Serialize(optsSerializer)
 		pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
 
 		sent := r.Stats().ICMP.V6PacketsSent
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 6e9306d09..98b1c807c 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -3384,6 +3384,10 @@ func TestRouterSolicitation(t *testing.T) {
 	tests := []struct {
 		name                        string
 		linkHeaderLen               uint16
+		linkAddr                    tcpip.LinkAddress
+		nicAddr                     tcpip.Address
+		expectedSrcAddr             tcpip.Address
+		expectedNDPOpts             []header.NDPOption
 		maxRtrSolicit               uint8
 		rtrSolicitInt               time.Duration
 		effectiveRtrSolicitInt      time.Duration
@@ -3392,6 +3396,7 @@ func TestRouterSolicitation(t *testing.T) {
 	}{
 		{
 			name:                        "Single RS with delay",
+			expectedSrcAddr:             header.IPv6Any,
 			maxRtrSolicit:               1,
 			rtrSolicitInt:               time.Second,
 			effectiveRtrSolicitInt:      time.Second,
@@ -3401,6 +3406,8 @@ func TestRouterSolicitation(t *testing.T) {
 		{
 			name:                        "Two RS with delay",
 			linkHeaderLen:               1,
+			nicAddr:                     llAddr1,
+			expectedSrcAddr:             llAddr1,
 			maxRtrSolicit:               2,
 			rtrSolicitInt:               time.Second,
 			effectiveRtrSolicitInt:      time.Second,
@@ -3408,8 +3415,14 @@ func TestRouterSolicitation(t *testing.T) {
 			effectiveMaxRtrSolicitDelay: 500 * time.Millisecond,
 		},
 		{
-			name:                        "Single RS without delay",
-			linkHeaderLen:               2,
+			name:            "Single RS without delay",
+			linkHeaderLen:   2,
+			linkAddr:        linkAddr1,
+			nicAddr:         llAddr1,
+			expectedSrcAddr: llAddr1,
+			expectedNDPOpts: []header.NDPOption{
+				header.NDPSourceLinkLayerAddressOption(linkAddr1),
+			},
 			maxRtrSolicit:               1,
 			rtrSolicitInt:               time.Second,
 			effectiveRtrSolicitInt:      time.Second,
@@ -3419,6 +3432,8 @@ func TestRouterSolicitation(t *testing.T) {
 		{
 			name:                        "Two RS without delay and invalid zero interval",
 			linkHeaderLen:               3,
+			linkAddr:                    linkAddr1,
+			expectedSrcAddr:             header.IPv6Any,
 			maxRtrSolicit:               2,
 			rtrSolicitInt:               0,
 			effectiveRtrSolicitInt:      4 * time.Second,
@@ -3427,6 +3442,8 @@ func TestRouterSolicitation(t *testing.T) {
 		},
 		{
 			name:                        "Three RS without delay",
+			linkAddr:                    linkAddr1,
+			expectedSrcAddr:             header.IPv6Any,
 			maxRtrSolicit:               3,
 			rtrSolicitInt:               500 * time.Millisecond,
 			effectiveRtrSolicitInt:      500 * time.Millisecond,
@@ -3435,6 +3452,8 @@ func TestRouterSolicitation(t *testing.T) {
 		},
 		{
 			name:                        "Two RS with invalid negative delay",
+			linkAddr:                    linkAddr1,
+			expectedSrcAddr:             header.IPv6Any,
 			maxRtrSolicit:               2,
 			rtrSolicitInt:               time.Second,
 			effectiveRtrSolicitInt:      time.Second,
@@ -3457,7 +3476,7 @@ func TestRouterSolicitation(t *testing.T) {
 			t.Run(test.name, func(t *testing.T) {
 				t.Parallel()
 				e := channelLinkWithHeaderLength{
-					Endpoint:     channel.New(int(test.maxRtrSolicit), 1280, linkAddr1),
+					Endpoint:     channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
 					headerLength: test.linkHeaderLen,
 				}
 				e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
@@ -3481,10 +3500,10 @@ func TestRouterSolicitation(t *testing.T) {
 
 					checker.IPv6(t,
 						p.Pkt.Header.View(),
-						checker.SrcAddr(header.IPv6Any),
+						checker.SrcAddr(test.expectedSrcAddr),
 						checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
 						checker.TTL(header.NDPHopLimit),
-						checker.NDPRS(),
+						checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
 					)
 
 					if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
@@ -3510,13 +3529,19 @@ func TestRouterSolicitation(t *testing.T) {
 					t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
 				}
 
-				// Make sure each RS got sent at the right
-				// times.
+				if addr := test.nicAddr; addr != "" {
+					if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+						t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+					}
+				}
+
+				// Make sure each RS is sent at the right time.
 				remaining := test.maxRtrSolicit
 				if remaining > 0 {
 					waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
 					remaining--
 				}
+
 				for ; remaining > 0; remaining-- {
 					waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout)
 					waitForPkt(defaultAsyncEventTimeout)
-- 
cgit v1.2.3