diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-03-16 11:07:02 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-03-16 11:09:26 -0700 |
commit | 68065d1ceb589b7ea1d3e4b3b067fb8772e30760 (patch) | |
tree | f3017f52fba725114b913cf893fcdcb6678415de /pkg | |
parent | ebd7c1b889e5d212f4a694d3addbada241936e8e (diff) |
Detect looped-back NDP DAD messages
...as per RFC 7527.
If a looped-back DAD message is received, do not fail DAD since our own
DAD message does not indicate that a neighbor has the address assigned.
Test: ndp_test.TestDADResolveLoopback
PiperOrigin-RevId: 363224288
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/checker/checker.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/header/ndp_options.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/header/ndp_test.go | 679 | ||||
-rw-r--r-- | pkg/tcpip/header/ndpoptionidentifier_string.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/network/internal/ip/duplicate_address_detection.go | 139 | ||||
-rw-r--r-- | pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go | 124 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 55 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 89 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/mld_test.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 72 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 79 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 18 |
14 files changed, 882 insertions, 479 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index fc622b246..fef065b05 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -1287,6 +1287,13 @@ func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) } + case header.NDPNonceOption: + gotOpt, ok := opt.(header.NDPNonceOption) + if !ok { + t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) + } else if diff := cmp.Diff(wantOpt.Nonce(), gotOpt.Nonce()); diff != "" { + t.Errorf("nonce mismatch (-want +got):\n%s", diff) + } default: t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) } diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 554242f0c..5deae465c 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -42,13 +42,17 @@ const ( // option, as per RFC 4861 section 4.6.2. NDPPrefixInformationType NDPOptionIdentifier = 3 + // NDPNonceOptionType is the type of the Nonce option, as per + // RFC 3971 section 5.3.2. + NDPNonceOptionType NDPOptionIdentifier = 14 + // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS // Server option, as per RFC 8106 section 5.1. NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25 // NDPDNSSearchListOptionType is the type of the DNS Search List option, // as per RFC 8106 section 5.2. - NDPDNSSearchListOptionType = 31 + NDPDNSSearchListOptionType NDPOptionIdentifier = 31 ) const ( @@ -231,6 +235,9 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { case NDPTargetLinkLayerAddressOptionType: return NDPTargetLinkLayerAddressOption(body), false, nil + case NDPNonceOptionType: + return NDPNonceOption(body), false, nil + case NDPPrefixInformationType: // Make sure the length of a Prefix Information option // body is ndpPrefixInformationLength, as per RFC 4861 @@ -416,6 +423,37 @@ func (b NDPOptionsSerializer) Length() int { return l } +// NDPNonceOption is the NDP Nonce Option as defined by RFC 3971 section 5.3.2. +// +// It is the first X bytes following the NDP option's Type and Length field +// where X is the value in Length multiplied by lengthByteUnits - 2 bytes. +type NDPNonceOption []byte + +// Type implements NDPOption. +func (o NDPNonceOption) Type() NDPOptionIdentifier { + return NDPNonceOptionType +} + +// Length implements NDPOption. +func (o NDPNonceOption) Length() int { + return len(o) +} + +// serializeInto implements NDPOption. +func (o NDPNonceOption) serializeInto(b []byte) int { + return copy(b, o) +} + +// String implements fmt.Stringer. +func (o NDPNonceOption) String() string { + return fmt.Sprintf("%T(%x)", o, []byte(o)) +} + +// Nonce returns the nonce value this option holds. +func (o NDPNonceOption) Nonce() []byte { + return []byte(o) +} + // NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option // as defined by RFC 4861 section 4.6.1. // diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index dc4591253..bfc73e245 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -16,6 +16,7 @@ package header import ( "bytes" + "encoding/binary" "errors" "fmt" "io" @@ -192,90 +193,6 @@ func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) { } } -// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing a -// NDPSourceLinkLayerAddressOption. -func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) { - tests := []struct { - name string - buf []byte - expectedBuf []byte - addr tcpip.LinkAddress - }{ - { - "Ethernet", - make([]byte, 8), - []byte{1, 1, 1, 2, 3, 4, 5, 6}, - "\x01\x02\x03\x04\x05\x06", - }, - { - "Padding", - []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, - "\x01\x02\x03\x04\x05\x06\x07\x08", - }, - { - "Empty", - nil, - nil, - "", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := NDPOptions(test.buf) - serializer := NDPOptionsSerializer{ - NDPSourceLinkLayerAddressOption(test.addr), - } - if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { - t.Fatalf("got Length = %d, want = %d", got, want) - } - opts.Serialize(serializer) - if !bytes.Equal(test.buf, test.expectedBuf) { - t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - if len(test.expectedBuf) > 0 { - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) - } - sll := next.(NDPSourceLinkLayerAddressOption) - if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) { - t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - - if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { - t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want) - } - } - - // Iterator should not return anything else. - next, done, err := it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } - }) - } -} - // TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the // Ethernet address from an NDPTargetLinkLayerAddressOption. func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { @@ -311,32 +228,309 @@ func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { } } -// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a -// NDPTargetLinkLayerAddressOption. -func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { +func TestOpts(t *testing.T) { + const optionHeaderLen = 2 + + checkNonce := func(expectedNonce []byte) func(*testing.T, NDPOption) { + return func(t *testing.T, opt NDPOption) { + if got := opt.Type(); got != NDPNonceOptionType { + t.Errorf("got Type() = %d, want = %d", got, NDPNonceOptionType) + } + nonce, ok := opt.(NDPNonceOption) + if !ok { + t.Fatalf("got nonce = %T, want = NDPNonceOption", opt) + } + if diff := cmp.Diff(expectedNonce, nonce.Nonce()); diff != "" { + t.Errorf("nonce mismatch (-want +got):\n%s", diff) + } + } + } + + checkTLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) { + return func(t *testing.T, opt NDPOption) { + if got := opt.Type(); got != NDPTargetLinkLayerAddressOptionType { + t.Errorf("got Type() = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + } + tll, ok := opt.(NDPTargetLinkLayerAddressOption) + if !ok { + t.Fatalf("got tll = %T, want = NDPTargetLinkLayerAddressOption", opt) + } + if got, want := tll.EthernetAddress(), expectedAddr; got != want { + t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want) + } + } + } + + checkSLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) { + return func(t *testing.T, opt NDPOption) { + if got := opt.Type(); got != NDPSourceLinkLayerAddressOptionType { + t.Errorf("got Type() = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) + } + sll, ok := opt.(NDPSourceLinkLayerAddressOption) + if !ok { + t.Fatalf("got sll = %T, want = NDPSourceLinkLayerAddressOption", opt) + } + if got, want := sll.EthernetAddress(), expectedAddr; got != want { + t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want) + } + } + } + + const validLifetimeSeconds = 16909060 + const address = tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18") + + expectedRDNSSBytes := [...]byte{ + // Type, Length + 25, 3, + + // Reserved + 0, 0, + + // Lifetime + 1, 2, 4, 8, + + // Address + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + binary.BigEndian.PutUint32(expectedRDNSSBytes[4:], validLifetimeSeconds) + if n := copy(expectedRDNSSBytes[8:], address); n != IPv6AddressSize { + t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize) + } + // Update reserved fields to non zero values to make sure serializing sets + // them to zero. + rdnssBytes := expectedRDNSSBytes + rdnssBytes[1] = 1 + rdnssBytes[2] = 2 + + const searchListPaddingBytes = 3 + const domainName = "abc.abcd.e" + expectedSearchListBytes := [...]byte{ + // Type, Length + 31, 3, + + // Reserved + 0, 0, + + // Lifetime + 1, 0, 0, 0, + + // Domain names + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 1, 'e', + 0, + 0, 0, 0, 0, + } + binary.BigEndian.PutUint32(expectedSearchListBytes[4:], validLifetimeSeconds) + // Update reserved fields to non zero values to make sure serializing sets + // them to zero. + searchListBytes := expectedSearchListBytes + searchListBytes[2] = 1 + searchListBytes[3] = 2 + + const prefixLength = 43 + const onLinkFlag = false + const slaacFlag = true + const preferredLifetimeSeconds = 84281096 + const onLinkFlagBit = 7 + const slaacFlagBit = 6 + boolToByte := func(v bool) byte { + if v { + return 1 + } + return 0 + } + flags := boolToByte(onLinkFlag)<<onLinkFlagBit | boolToByte(slaacFlag)<<slaacFlagBit + expectedPrefixInformationBytes := [...]byte{ + // Type, Length + 3, 4, + + prefixLength, flags, + + // Valid Lifetime + 1, 2, 3, 4, + + // Preferred Lifetime + 5, 6, 7, 8, + + // Reserved2 + 0, 0, 0, 0, + + // Address + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + binary.BigEndian.PutUint32(expectedPrefixInformationBytes[4:], validLifetimeSeconds) + binary.BigEndian.PutUint32(expectedPrefixInformationBytes[8:], preferredLifetimeSeconds) + if n := copy(expectedPrefixInformationBytes[16:], address); n != IPv6AddressSize { + t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize) + } + // Update reserved fields to non zero values to make sure serializing sets + // them to zero. + prefixInformationBytes := expectedPrefixInformationBytes + prefixInformationBytes[3] |= (1 << slaacFlagBit) - 1 + binary.BigEndian.PutUint32(prefixInformationBytes[12:], validLifetimeSeconds+1) tests := []struct { name string buf []byte + opt NDPOption expectedBuf []byte - addr tcpip.LinkAddress + check func(*testing.T, NDPOption) }{ { - "Ethernet", - make([]byte, 8), - []byte{2, 1, 1, 2, 3, 4, 5, 6}, - "\x01\x02\x03\x04\x05\x06", + name: "Nonce", + buf: make([]byte, 8), + opt: NDPNonceOption([]byte{1, 2, 3, 4, 5, 6}), + expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 6}, + check: checkNonce([]byte{1, 2, 3, 4, 5, 6}), + }, + { + name: "Nonce with padding", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1}, + opt: NDPNonceOption([]byte{1, 2, 3, 4, 5}), + expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 0}, + check: checkNonce([]byte{1, 2, 3, 4, 5, 0}), + }, + + { + name: "TLL Ethernet", + buf: make([]byte, 8), + opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"), + expectedBuf: []byte{2, 1, 1, 2, 3, 4, 5, 6}, + check: checkTLL("\x01\x02\x03\x04\x05\x06"), + }, + { + name: "TLL Padding", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"), + expectedBuf: []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, + check: checkTLL("\x01\x02\x03\x04\x05\x06"), + }, + { + name: "TLL Empty", + buf: nil, + opt: NDPTargetLinkLayerAddressOption(""), + expectedBuf: nil, + }, + + { + name: "SLL Ethernet", + buf: make([]byte, 8), + opt: NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"), + expectedBuf: []byte{1, 1, 1, 2, 3, 4, 5, 6}, + check: checkSLL("\x01\x02\x03\x04\x05\x06"), + }, + { + name: "SLL Padding", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + opt: NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"), + expectedBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, + check: checkSLL("\x01\x02\x03\x04\x05\x06"), + }, + { + name: "SLL Empty", + buf: nil, + opt: NDPSourceLinkLayerAddressOption(""), + expectedBuf: nil, + }, + + { + name: "RDNSS", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + // NDPRecursiveDNSServer holds the option after the header bytes. + opt: NDPRecursiveDNSServer(rdnssBytes[optionHeaderLen:]), + expectedBuf: expectedRDNSSBytes[:], + check: func(t *testing.T, opt NDPOption) { + if got := opt.Type(); got != NDPRecursiveDNSServerOptionType { + t.Errorf("got Type() = %d, want = %d", got, NDPRecursiveDNSServerOptionType) + } + rdnss, ok := opt.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("got opt = %T, want = NDPRecursiveDNSServer", opt) + } + if got, want := rdnss.Length(), len(expectedRDNSSBytes[optionHeaderLen:]); got != want { + t.Errorf("got Length() = %d, want = %d", got, want) + } + if got, want := rdnss.Lifetime(), validLifetimeSeconds*time.Second; got != want { + t.Errorf("got Lifetime() = %s, want = %s", got, want) + } + if addrs, err := rdnss.Addresses(); err != nil { + t.Errorf("Addresses(): %s", err) + } else if diff := cmp.Diff([]tcpip.Address{address}, addrs); diff != "" { + t.Errorf("mismatched addresses (-want +got):\n%s", diff) + } + }, }, + { - "Padding", - []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, - "\x01\x02\x03\x04\x05\x06\x07\x08", + name: "Search list", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + opt: NDPDNSSearchList(searchListBytes[optionHeaderLen:]), + expectedBuf: expectedSearchListBytes[:], + check: func(t *testing.T, opt NDPOption) { + if got := opt.Type(); got != NDPDNSSearchListOptionType { + t.Errorf("got Type() = %d, want = %d", got, NDPDNSSearchListOptionType) + } + + dnssl, ok := opt.(NDPDNSSearchList) + if !ok { + t.Fatalf("got opt = %T, want = NDPDNSSearchList", opt) + } + if got, want := dnssl.Length(), len(expectedRDNSSBytes[optionHeaderLen:]); got != want { + t.Errorf("got Length() = %d, want = %d", got, want) + } + if got, want := dnssl.Lifetime(), validLifetimeSeconds*time.Second; got != want { + t.Errorf("got Lifetime() = %s, want = %s", got, want) + } + + if domainNames, err := dnssl.DomainNames(); err != nil { + t.Errorf("DomainNames(): %s", err) + } else if diff := cmp.Diff([]string{domainName}, domainNames); diff != "" { + t.Errorf("domain names mismatch (-want +got):\n%s", diff) + } + }, }, + { - "Empty", - nil, - nil, - "", + name: "Prefix Information", + buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + // NDPPrefixInformation holds the option after the header bytes. + opt: NDPPrefixInformation(prefixInformationBytes[optionHeaderLen:]), + expectedBuf: expectedPrefixInformationBytes[:], + check: func(t *testing.T, opt NDPOption) { + if got := opt.Type(); got != NDPPrefixInformationType { + t.Errorf("got Type() = %d, want = %d", got, NDPPrefixInformationType) + } + + pi, ok := opt.(NDPPrefixInformation) + if !ok { + t.Fatalf("got opt = %T, want = NDPPrefixInformation", opt) + } + + if got, want := pi.Length(), len(expectedPrefixInformationBytes[optionHeaderLen:]); got != want { + t.Errorf("got Length() = %d, want = %d", got, want) + } + if got := pi.PrefixLength(); got != prefixLength { + t.Errorf("got PrefixLength() = %d, want = %d", got, prefixLength) + } + if got := pi.OnLinkFlag(); got != onLinkFlag { + t.Errorf("got OnLinkFlag() = %t, want = %t", got, onLinkFlag) + } + if got := pi.AutonomousAddressConfigurationFlag(); got != slaacFlag { + t.Errorf("got AutonomousAddressConfigurationFlag() = %t, want = %t", got, slaacFlag) + } + if got, want := pi.ValidLifetime(), validLifetimeSeconds*time.Second; got != want { + t.Errorf("got ValidLifetime() = %s, want = %s", got, want) + } + if got, want := pi.PreferredLifetime(), preferredLifetimeSeconds*time.Second; got != want { + t.Errorf("got PreferredLifetime() = %s, want = %s", got, want) + } + if got := pi.Prefix(); got != address { + t.Errorf("got Prefix() = %s, want = %s", got, address) + } + }, }, } @@ -344,230 +538,47 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := NDPOptions(test.buf) serializer := NDPOptionsSerializer{ - NDPTargetLinkLayerAddressOption(test.addr), + test.opt, } if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { - t.Fatalf("got Length = %d, want = %d", got, want) + t.Fatalf("got Length() = %d, want = %d", got, want) } opts.Serialize(serializer) - if !bytes.Equal(test.buf, test.expectedBuf) { - t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) + if diff := cmp.Diff(test.expectedBuf, test.buf); diff != "" { + t.Fatalf("serialized buffer mismatch (-want +got):\n%s", diff) } it, err := opts.Iter(true) if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + t.Fatalf("got Iter(true) = (_, %s), want = (_, nil)", err) } if len(test.expectedBuf) > 0 { next, done, err := it.Next() if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + t.Fatalf("got Next() = (_, _, %s), want = (_, _, nil)", err) } if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { - t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) - } - tll := next.(NDPTargetLinkLayerAddressOption) - if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) { - t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) - } - - if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { - t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want) + t.Fatal("got Next() = (_, true, _), want = (_, false, _)") } + test.check(t, next) } // Iterator should not return anything else. next, done, err := it.Next() if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err) } if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") + t.Error("got Next() = (_, false, _), want = (_, true, _)") } if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next) } }) } } -// TestNDPPrefixInformationOption tests the field getters and serialization of a -// NDPPrefixInformation. -func TestNDPPrefixInformationOption(t *testing.T) { - b := []byte{ - 43, 127, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 5, 5, 5, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPPrefixInformation(b), - } - opts.Serialize(serializer) - expectedBuf := []byte{ - 3, 4, 43, 64, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - 21, 22, 23, 24, - } - if !bytes.Equal(targetBuf, expectedBuf) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPPrefixInformationType { - t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) - } - - pi := next.(NDPPrefixInformation) - - if got := pi.Type(); got != 3 { - t.Errorf("got Type = %d, want = 3", got) - } - - if got := pi.Length(); got != 30 { - t.Errorf("got Length = %d, want = 30", got) - } - - if got := pi.PrefixLength(); got != 43 { - t.Errorf("got PrefixLength = %d, want = 43", got) - } - - if pi.OnLinkFlag() { - t.Error("got OnLinkFlag = true, want = false") - } - - if !pi.AutonomousAddressConfigurationFlag() { - t.Error("got AutonomousAddressConfigurationFlag = false, want = true") - } - - if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want { - t.Errorf("got ValidLifetime = %d, want = %d", got, want) - } - - if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want { - t.Errorf("got PreferredLifetime = %d, want = %d", got, want) - } - - if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want { - t.Errorf("got Prefix = %s, want = %s", got, want) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - -func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { - b := []byte{ - 9, 8, - 1, 2, 4, 8, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - } - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - expected := []byte{ - 25, 3, 0, 0, - 1, 2, 4, 8, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - } - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPRecursiveDNSServer(b), - } - if got, want := opts.Serialize(serializer), len(expected); got != want { - t.Errorf("got Serialize = %d, want = %d", got, want) - } - if !bytes.Equal(targetBuf, expected) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPRecursiveDNSServerOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) - } - - opt, ok := next.(NDPRecursiveDNSServer) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) - } - if got := opt.Type(); got != 25 { - t.Errorf("got Type = %d, want = 31", got) - } - if got := opt.Length(); got != 22 { - t.Errorf("got Length = %d, want = 22", got) - } - if got, want := opt.Lifetime(), 16909320*time.Second; got != want { - t.Errorf("got Lifetime = %s, want = %s", got, want) - } - want := []tcpip.Address{ - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - } - addrs, err := opt.Addresses() - if err != nil { - t.Errorf("opt.Addresses() = %s", err) - } - if diff := cmp.Diff(addrs, want); diff != "" { - t.Errorf("mismatched addresses (-want +got):\n%s", diff) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - func TestNDPRecursiveDNSServerOption(t *testing.T) { tests := []struct { name string @@ -1060,86 +1071,6 @@ func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) { } } -func TestNDPDNSSearchListOptionSerialize(t *testing.T) { - b := []byte{ - 9, 8, - 1, 0, 0, 0, - 3, 'a', 'b', 'c', - 4, 'a', 'b', 'c', 'd', - 1, 'e', - 0, - } - targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} - expected := []byte{ - 31, 3, 0, 0, - 1, 0, 0, 0, - 3, 'a', 'b', 'c', - 4, 'a', 'b', 'c', 'd', - 1, 'e', - 0, - 0, 0, 0, 0, - } - opts := NDPOptions(targetBuf) - serializer := NDPOptionsSerializer{ - NDPDNSSearchList(b), - } - if got, want := opts.Serialize(serializer), len(expected); got != want { - t.Errorf("got Serialize = %d, want = %d", got, want) - } - if !bytes.Equal(targetBuf, expected) { - t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) - } - - it, err := opts.Iter(true) - if err != nil { - t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) - } - - next, done, err := it.Next() - if err != nil { - t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if done { - t.Fatal("got Next = (_, true, _), want = (_, false, _)") - } - if got := next.Type(); got != NDPDNSSearchListOptionType { - t.Errorf("got Type = %d, want = %d", got, NDPDNSSearchListOptionType) - } - - opt, ok := next.(NDPDNSSearchList) - if !ok { - t.Fatalf("next (type = %T) cannot be casted to an NDPDNSSearchList", next) - } - if got := opt.Type(); got != 31 { - t.Errorf("got Type = %d, want = 31", got) - } - if got := opt.Length(); got != 22 { - t.Errorf("got Length = %d, want = 22", got) - } - if got, want := opt.Lifetime(), 16777216*time.Second; got != want { - t.Errorf("got Lifetime = %s, want = %s", got, want) - } - domainNames, err := opt.DomainNames() - if err != nil { - t.Errorf("opt.DomainNames() = %s", err) - } - if diff := cmp.Diff(domainNames, []string{"abc.abcd.e"}); diff != "" { - t.Errorf("domain names mismatch (-want +got):\n%s", diff) - } - - // Iterator should not return anything else. - next, done, err = it.Next() - if err != nil { - t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) - } - if !done { - t.Error("got Next = (_, false, _), want = (_, true, _)") - } - if next != nil { - t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) - } -} - // TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions // the iterator was returned for is malformed. func TestNDPOptionsIterCheck(t *testing.T) { diff --git a/pkg/tcpip/header/ndpoptionidentifier_string.go b/pkg/tcpip/header/ndpoptionidentifier_string.go index 6fe9a336b..83f94a730 100644 --- a/pkg/tcpip/header/ndpoptionidentifier_string.go +++ b/pkg/tcpip/header/ndpoptionidentifier_string.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Code generated by "stringer -type NDPOptionIdentifier ."; DO NOT EDIT. +// Code generated by "stringer -type NDPOptionIdentifier"; DO NOT EDIT. package header @@ -25,12 +25,16 @@ func _() { _ = x[NDPSourceLinkLayerAddressOptionType-1] _ = x[NDPTargetLinkLayerAddressOptionType-2] _ = x[NDPPrefixInformationType-3] + _ = x[NDPNonceOptionType-14] _ = x[NDPRecursiveDNSServerOptionType-25] + _ = x[NDPDNSSearchListOptionType-31] } const ( _NDPOptionIdentifier_name_0 = "NDPSourceLinkLayerAddressOptionTypeNDPTargetLinkLayerAddressOptionTypeNDPPrefixInformationType" - _NDPOptionIdentifier_name_1 = "NDPRecursiveDNSServerOptionType" + _NDPOptionIdentifier_name_1 = "NDPNonceOptionType" + _NDPOptionIdentifier_name_2 = "NDPRecursiveDNSServerOptionType" + _NDPOptionIdentifier_name_3 = "NDPDNSSearchListOptionType" ) var ( @@ -42,8 +46,12 @@ func (i NDPOptionIdentifier) String() string { case 1 <= i && i <= 3: i -= 1 return _NDPOptionIdentifier_name_0[_NDPOptionIdentifier_index_0[i]:_NDPOptionIdentifier_index_0[i+1]] - case i == 25: + case i == 14: return _NDPOptionIdentifier_name_1 + case i == 25: + return _NDPOptionIdentifier_name_2 + case i == 31: + return _NDPOptionIdentifier_name_3 default: return "NDPOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 43a4b7cac..7ae38d684 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -38,6 +38,7 @@ const ( var _ stack.DuplicateAddressDetector = (*endpoint)(nil) var _ stack.LinkAddressResolver = (*endpoint)(nil) +var _ ip.DADProtocol = (*endpoint)(nil) // ARP endpoints need to implement stack.NetworkEndpoint because the stack // considers the layer above the link-layer a network layer; the only @@ -82,7 +83,8 @@ func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } -func (e *endpoint) SendDADMessage(addr tcpip.Address) tcpip.Error { +// SendDADMessage implements ip.DADProtocol. +func (e *endpoint) SendDADMessage(addr tcpip.Address, _ []byte) tcpip.Error { return e.sendARPRequest(header.IPv4Any, addr, header.EthernetBroadcastAddress) } @@ -284,9 +286,12 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran e.mu.Lock() e.mu.dad.Init(&e.mu, p.options.DADConfigs, ip.DADOptions{ - Clock: p.stack.Clock(), - Protocol: e, - NICID: nic.ID(), + Clock: p.stack.Clock(), + SecureRNG: p.stack.SecureRNG(), + // ARP does not support sending nonce values. + NonceSize: 0, + Protocol: e, + NICID: nic.ID(), }) e.mu.Unlock() diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go index 0053646ee..eed49f5d2 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go @@ -16,14 +16,27 @@ package ip import ( + "bytes" "fmt" + "io" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +type extendRequest int + +const ( + notRequested extendRequest = iota + requested + extended +) + type dadState struct { + nonce []byte + extendRequest extendRequest + done *bool timer tcpip.Timer @@ -33,14 +46,17 @@ type dadState struct { // DADProtocol is a protocol whose core state machine can be represented by DAD. type DADProtocol interface { // SendDADMessage attempts to send a DAD probe message. - SendDADMessage(tcpip.Address) tcpip.Error + SendDADMessage(tcpip.Address, []byte) tcpip.Error } // DADOptions holds options for DAD. type DADOptions struct { - Clock tcpip.Clock - Protocol DADProtocol - NICID tcpip.NICID + Clock tcpip.Clock + SecureRNG io.Reader + NonceSize uint8 + ExtendDADTransmits uint8 + Protocol DADProtocol + NICID tcpip.NICID } // DAD performs duplicate address detection for addresses. @@ -63,6 +79,10 @@ func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts panic("attempted to initialize DAD state twice") } + if opts.NonceSize != 0 && opts.ExtendDADTransmits == 0 { + panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize)) + } + *d = DAD{ opts: opts, configs: configs, @@ -96,10 +116,55 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet s = dadState{ done: &done, timer: d.opts.Clock.AfterFunc(0, func() { - var err tcpip.Error dadDone := remaining == 0 + + nonce, earlyReturn := func() ([]byte, bool) { + d.protocolMU.Lock() + defer d.protocolMU.Unlock() + + if done { + return nil, true + } + + s, ok := d.addresses[addr] + if !ok { + panic(fmt.Sprintf("dad: timer fired but missing state for %s on NIC(%d)", addr, d.opts.NICID)) + } + + // As per RFC 7527 section 4 + // + // If any probe is looped back within RetransTimer milliseconds + // after having sent DupAddrDetectTransmits NS(DAD) messages, the + // interface continues with another MAX_MULTICAST_SOLICIT number of + // NS(DAD) messages transmitted RetransTimer milliseconds apart. + if dadDone && s.extendRequest == requested { + dadDone = false + remaining = d.opts.ExtendDADTransmits + s.extendRequest = extended + } + + if !dadDone && d.opts.NonceSize != 0 { + if s.nonce == nil { + s.nonce = make([]byte, d.opts.NonceSize) + } + + if n, err := io.ReadFull(d.opts.SecureRNG, s.nonce); err != nil { + panic(fmt.Sprintf("SecureRNG.Read(...): %s", err)) + } else if n != len(s.nonce) { + panic(fmt.Sprintf("expected to read %d bytes from secure RNG, only read %d bytes", len(s.nonce), n)) + } + } + + d.addresses[addr] = s + return s.nonce, false + }() + if earlyReturn { + return + } + + var err tcpip.Error if !dadDone { - err = d.opts.Protocol.SendDADMessage(addr) + err = d.opts.Protocol.SendDADMessage(addr, nonce) } d.protocolMU.Lock() @@ -142,6 +207,68 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet return ret } +// ExtendIfNonceEqualLockedDisposition enumerates the possible results from +// ExtendIfNonceEqualLocked. +type ExtendIfNonceEqualLockedDisposition int + +const ( + // Extended indicates that the DAD process was extended. + Extended ExtendIfNonceEqualLockedDisposition = iota + + // AlreadyExtended indicates that the DAD process was already extended. + AlreadyExtended + + // NoDADStateFound indicates that DAD state was not found for the address. + NoDADStateFound + + // NonceDisabled indicates that nonce values are not sent with DAD messages. + NonceDisabled + + // NonceNotEqual indicates that the nonce value passed and the nonce in the + // last send DAD message are not equal. + NonceNotEqual +) + +// ExtendIfNonceEqualLocked extends the DAD process if the provided nonce is the +// same as the nonce sent in the last DAD message. +// +// Precondition: d.protocolMU must be locked. +func (d *DAD) ExtendIfNonceEqualLocked(addr tcpip.Address, nonce []byte) ExtendIfNonceEqualLockedDisposition { + s, ok := d.addresses[addr] + if !ok { + return NoDADStateFound + } + + if d.opts.NonceSize == 0 { + return NonceDisabled + } + + if s.extendRequest != notRequested { + return AlreadyExtended + } + + // As per RFC 7527 section 4 + // + // If any probe is looped back within RetransTimer milliseconds after having + // sent DupAddrDetectTransmits NS(DAD) messages, the interface continues + // with another MAX_MULTICAST_SOLICIT number of NS(DAD) messages transmitted + // RetransTimer milliseconds apart. + // + // If a DAD message has already been sent and the nonce value we observed is + // the same as the nonce value we last sent, then we assume our probe was + // looped back and request an extension to the DAD process. + // + // Note, the first DAD message is sent asynchronously so we need to make sure + // that we sent a DAD message by checking if we have a nonce value set. + if s.nonce != nil && bytes.Equal(s.nonce, nonce) { + s.extendRequest = requested + d.addresses[addr] = s + return Extended + } + + return NonceNotEqual +} + // StopLocked stops a currently running DAD process. // // Precondition: d.protocolMU must be locked. diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go index e00aa4678..a22b712c6 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go @@ -15,6 +15,7 @@ package ip_test import ( + "bytes" "testing" "time" @@ -32,8 +33,8 @@ type mockDADProtocol struct { mu struct { sync.Mutex - dad ip.DAD - sendCount map[tcpip.Address]int + dad ip.DAD + sentNonces map[tcpip.Address][][]byte } } @@ -48,26 +49,30 @@ func (m *mockDADProtocol) init(t *testing.T, c stack.DADConfigurations, opts ip. } func (m *mockDADProtocol) initLocked() { - m.mu.sendCount = make(map[tcpip.Address]int) + m.mu.sentNonces = make(map[tcpip.Address][][]byte) } -func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address) tcpip.Error { +func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error { m.mu.Lock() defer m.mu.Unlock() - m.mu.sendCount[addr]++ + m.mu.sentNonces[addr] = append(m.mu.sentNonces[addr], nonce) return nil } func (m *mockDADProtocol) check(addrs []tcpip.Address) string { - m.mu.Lock() - defer m.mu.Unlock() - - sendCount := make(map[tcpip.Address]int) + sentNonces := make(map[tcpip.Address][][]byte) for _, a := range addrs { - sendCount[a]++ + sentNonces[a] = append(sentNonces[a], nil) } - diff := cmp.Diff(sendCount, m.mu.sendCount) + return m.checkWithNonce(sentNonces) +} + +func (m *mockDADProtocol) checkWithNonce(expectedSentNonces map[tcpip.Address][][]byte) string { + m.mu.Lock() + defer m.mu.Unlock() + + diff := cmp.Diff(expectedSentNonces, m.mu.sentNonces) m.initLocked() return diff } @@ -84,6 +89,12 @@ func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) { m.mu.dad.StopLocked(addr, reason) } +func (m *mockDADProtocol) extendIfNonceEqual(addr tcpip.Address, nonce []byte) ip.ExtendIfNonceEqualLockedDisposition { + m.mu.Lock() + defer m.mu.Unlock() + return m.mu.dad.ExtendIfNonceEqualLocked(addr, nonce) +} + func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) { m.mu.Lock() defer m.mu.Unlock() @@ -277,3 +288,94 @@ func TestDADStop(t *testing.T) { default: } } + +func TestNonce(t *testing.T) { + const ( + nonceSize = 2 + + extendRequestAttempts = 2 + + dupAddrDetectTransmits = 2 + extendTransmits = 5 + ) + + var secureRNGBytes [nonceSize * (dupAddrDetectTransmits + extendTransmits)]byte + for i := range secureRNGBytes { + secureRNGBytes[i] = byte(i) + } + + tests := []struct { + name string + mockedReceivedNonce []byte + expectedResults [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition + expectedTransmits int + }{ + { + name: "not matching", + mockedReceivedNonce: []byte{0, 0}, + expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.NonceNotEqual, ip.NonceNotEqual}, + expectedTransmits: dupAddrDetectTransmits, + }, + { + name: "matching nonce", + mockedReceivedNonce: secureRNGBytes[:nonceSize], + expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.Extended, ip.AlreadyExtended}, + expectedTransmits: dupAddrDetectTransmits + extendTransmits, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var dad mockDADProtocol + clock := faketime.NewManualClock() + dadConfigs := stack.DADConfigurations{ + DupAddrDetectTransmits: dupAddrDetectTransmits, + RetransmitTimer: time.Second, + } + + var secureRNG bytes.Reader + secureRNG.Reset(secureRNGBytes[:]) + dad.init(t, dadConfigs, ip.DADOptions{ + Clock: clock, + SecureRNG: &secureRNG, + NonceSize: nonceSize, + ExtendDADTransmits: extendTransmits, + }) + + ch := make(chan dadResult, 1) + if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { + t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) + } + + clock.Advance(0) + for i, want := range test.expectedResults { + if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want { + t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want) + } + } + + for i := 0; i < test.expectedTransmits; i++ { + if diff := dad.checkWithNonce(map[tcpip.Address][][]byte{ + addr1: { + secureRNGBytes[nonceSize*i:][:nonceSize], + }, + }); diff != "" { + t.Errorf("(i=%d) dad check mismatch (-want +got):\n%s", i, diff) + } + + clock.Advance(dadConfigs.RetransmitTimer) + } + + if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { + t.Errorf("dad result mismatch (-want +got):\n%s", diff) + } + + // Should not have anymore updates. + select { + case r := <-ch: + t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) + default: + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 8059e0690..2afa856dc 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -369,6 +369,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } + var it header.NDPOptionIterator + { + var err error + it, err = ns.Options().Iter(false /* check */) + if err != nil { + // Options are not valid as per the wire format, silently drop the + // packet. + received.invalid.Increment() + return + } + } + if e.hasTentativeAddr(targetAddr) { // If the target address is tentative and the source of the packet is a // unicast (specified) address, then the source of the packet is @@ -382,6 +394,22 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // stack know so it can handle such a scenario and do nothing further with // the NS. if srcAddr == header.IPv6Any { + var nonce []byte + for { + opt, done, err := it.Next() + if err != nil { + received.invalid.Increment() + return + } + if done { + break + } + if n, ok := opt.(header.NDPNonceOption); ok { + nonce = n.Nonce() + break + } + } + // Since this is a DAD message we know the sender does not actually hold // the target address so there is no "holder". var holderLinkAddress tcpip.LinkAddress @@ -397,7 +425,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. - switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress); err.(type) { + switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress, nonce); err.(type) { case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: default: panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) @@ -418,21 +446,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - var sourceLinkAddr tcpip.LinkAddress - { - it, err := ns.Options().Iter(false /* check */) - if err != nil { - // Options are not valid as per the wire format, silently drop the - // packet. - received.invalid.Increment() - return - } - - sourceLinkAddr, ok = getSourceLinkAddr(it) - if !ok { - received.invalid.Increment() - return - } + sourceLinkAddr, ok := getSourceLinkAddr(it) + if !ok { + received.invalid.Increment() + return } // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST @@ -586,6 +603,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r e.dad.mu.Unlock() if e.hasTentativeAddr(targetAddr) { + // We only send a nonce value in DAD messages to check for loopedback + // messages so we use the empty nonce value here. + var nonce []byte + // We just got an NA from a node that owns an address we are performing // DAD on, implying the address is not unique. In this case we let the // stack know so it can handle such a scenario and do nothing furthur with @@ -602,7 +623,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. - switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr); err.(type) { + switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr, nonce); err.(type) { case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: return default: diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 46b6cc41a..350493958 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -348,7 +348,7 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool { // dupTentativeAddrDetected removes the tentative address if it exists. If the // address was generated via SLAAC, an attempt is made to generate a new // address. -func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress) tcpip.Error { +func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress, nonce []byte) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -361,27 +361,48 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t return &tcpip.ErrInvalidEndpointState{} } - // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an - // attempt will be made to generate a new address for it. - if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil { - return err - } + switch result := e.mu.ndp.dad.ExtendIfNonceEqualLocked(addr, nonce); result { + case ip.Extended: + // The nonce we got back was the same we sent so we know the message + // indicating a duplicate address was likely ours so do not consider + // the address duplicate here. + return nil + case ip.AlreadyExtended: + // See Extended. + // + // Our DAD message was looped back already. + return nil + case ip.NoDADStateFound: + panic(fmt.Sprintf("expected DAD state for tentative address %s", addr)) + case ip.NonceDisabled: + // If nonce is disabled then we have no way to know if the packet was + // looped-back so we have to assume it indicates a duplicate address. + fallthrough + case ip.NonceNotEqual: + // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an + // attempt will be made to generate a new address for it. + if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil { + return err + } - prefix := addressEndpoint.Subnet() + prefix := addressEndpoint.Subnet() - switch t := addressEndpoint.ConfigType(); t { - case stack.AddressConfigStatic: - case stack.AddressConfigSlaac: - e.mu.ndp.regenerateSLAACAddr(prefix) - case stack.AddressConfigSlaacTemp: - // Do not reset the generation attempts counter for the prefix as the - // temporary address is being regenerated in response to a DAD conflict. - e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) + switch t := addressEndpoint.ConfigType(); t { + case stack.AddressConfigStatic: + case stack.AddressConfigSlaac: + e.mu.ndp.regenerateSLAACAddr(prefix) + case stack.AddressConfigSlaacTemp: + // Do not reset the generation attempts counter for the prefix as the + // temporary address is being regenerated in response to a DAD conflict. + e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) + default: + panic(fmt.Sprintf("unrecognized address config type = %d", t)) + } + + return nil default: - panic(fmt.Sprintf("unrecognized address config type = %d", t)) + panic(fmt.Sprintf("unhandled result = %d", result)) } - - return nil } // transitionForwarding transitions the endpoint's forwarding status to @@ -1797,16 +1818,36 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran dispatcher: dispatcher, protocol: p, } + + // NDP options must be 8 octet aligned and the first 2 bytes are used for + // the type and length fields leaving 6 octets as the minimum size for a + // nonce option without padding. + const nonceSize = 6 + + // As per RFC 7527 section 4.1, + // + // If any probe is looped back within RetransTimer milliseconds after + // having sent DupAddrDetectTransmits NS(DAD) messages, the interface + // continues with another MAX_MULTICAST_SOLICIT number of NS(DAD) + // messages transmitted RetransTimer milliseconds apart. + // + // Value taken from RFC 4861 section 10. + const maxMulticastSolicit = 3 + dadOptions := ip.DADOptions{ + Clock: p.stack.Clock(), + SecureRNG: p.stack.SecureRNG(), + NonceSize: nonceSize, + ExtendDADTransmits: maxMulticastSolicit, + Protocol: &e.mu.ndp, + NICID: nic.ID(), + } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.mu.ndp.init(e) + e.mu.ndp.init(e, dadOptions) e.mu.mld.init(e) e.dad.mu.Lock() - e.dad.mu.dad.Init(&e.dad.mu, p.options.DADConfigs, ip.DADOptions{ - Clock: p.stack.Clock(), - Protocol: &e.mu.ndp, - NICID: nic.ID(), - }) + e.dad.mu.dad.Init(&e.dad.mu, p.options.DADConfigs, dadOptions) e.dad.mu.Unlock() e.mu.Unlock() diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 9a425e50a..85a8f9944 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -15,6 +15,7 @@ package ipv6_test import ( + "bytes" "testing" "time" @@ -119,11 +120,26 @@ func TestSendQueuedMLDReports(t *testing.T) { }, } + nonce := [...]byte{ + 1, 2, 3, 4, 5, 6, + } + + const maxNSMessages = 2 + secureRNGBytes := make([]byte, len(nonce)*maxNSMessages) + for b := secureRNGBytes[:]; len(b) > 0; b = b[len(nonce):] { + if n := copy(b, nonce[:]); n != len(nonce) { + t.Fatalf("got copy(...) = %d, want = %d", n, len(nonce)) + } + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits) clock := faketime.NewManualClock() + var secureRNG bytes.Reader + secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: test.dadTransmits, @@ -154,7 +170,7 @@ func TestSendQueuedMLDReports(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(addr), - checker.NDPNSOptions(nil), + checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonce[:])}), )) } } diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index d9b728878..536493f87 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1789,18 +1789,14 @@ func (ndp *ndpState) stopSolicitingRouters() { ndp.rtrSolicitTimer = timer{} } -func (ndp *ndpState) init(ep *endpoint) { +func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) { if ndp.defaultRouters != nil { panic("attempted to initialize NDP state twice") } ndp.ep = ep ndp.configs = ep.protocol.options.NDPConfigs - ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, ip.DADOptions{ - Clock: ep.protocol.stack.Clock(), - Protocol: ndp, - NICID: ep.nic.ID(), - }) + ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, dadOptions) ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) @@ -1811,9 +1807,11 @@ func (ndp *ndpState) init(ep *endpoint) { } } -func (ndp *ndpState) SendDADMessage(addr tcpip.Address) tcpip.Error { +func (ndp *ndpState) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - return ndp.ep.sendNDPNS(header.IPv6Any, snmc, addr, header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* opts */) + return ndp.ep.sendNDPNS(header.IPv6Any, snmc, addr, header.EthernetAddressFromMulticastIPv6Address(snmc), header.NDPOptionsSerializer{ + header.NDPNonceOption(nonce), + }) } func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, opts header.NDPOptionsSerializer) tcpip.Error { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 6e850fd46..52b9a200c 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "bytes" "context" "strings" "testing" @@ -1264,8 +1265,21 @@ func TestCheckDuplicateAddress(t *testing.T) { DupAddrDetectTransmits: 1, RetransmitTimer: time.Second, } + + nonces := [...][]byte{ + {1, 2, 3, 4, 5, 6}, + {7, 8, 9, 10, 11, 12}, + } + + var secureRNGBytes []byte + for _, n := range nonces { + secureRNGBytes = append(secureRNGBytes, n...) + } + var secureRNG bytes.Reader + secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ - Clock: clock, + SecureRNG: &secureRNG, + Clock: clock, NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ DADConfigs: dadConfigs, })}, @@ -1278,10 +1292,36 @@ func TestCheckDuplicateAddress(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - dadPacketsSent := 1 + dadPacketsSent := 0 + snmc := header.SolicitedNodeAddr(lladdr0) + remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) + checkDADMsg := func() { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("expected %d-th DAD message", dadPacketsSent) + } + + if p.Proto != header.IPv6ProtocolNumber { + t.Errorf("(i=%d) got p.Proto = %d, want = %d", dadPacketsSent, p.Proto, header.IPv6ProtocolNumber) + } + + if p.Route.RemoteLinkAddress != remoteLinkAddr { + t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", dadPacketsSent, p.Route.RemoteLinkAddress, remoteLinkAddr) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(lladdr0), + checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}), + )) + } if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) } + checkDADMsg() // Start DAD for the address we just added. // @@ -1297,6 +1337,7 @@ func TestCheckDuplicateAddress(t *testing.T) { } else if res != stack.DADStarting { t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADStarting) } + checkDADMsg() // Remove the address and make sure our DAD request was not stopped. if err := s.RemoveAddress(nicID, lladdr0); err != nil { @@ -1328,33 +1369,6 @@ func TestCheckDuplicateAddress(t *testing.T) { default: } - snmc := header.SolicitedNodeAddr(lladdr0) - remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) - - for i := 0; i < dadPacketsSent; i++ { - p, ok := e.Read() - if !ok { - t.Fatalf("expected %d-th DAD message", i) - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Errorf("(i=%d) got p.Proto = %d, want = %d", i, p.Proto, header.IPv6ProtocolNumber) - } - - if p.Route.RemoteLinkAddress != remoteLinkAddr { - t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", i, p.Route.RemoteLinkAddress, remoteLinkAddr) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(lladdr0), - checker.NDPNSOptions(nil), - )) - } - // Should have no more packets. if p, ok := e.Read(); ok { t.Errorf("got unexpected packet = %#v", p) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 47796a6ba..43e6d102c 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "bytes" "context" "encoding/binary" "fmt" @@ -29,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" @@ -358,6 +360,66 @@ func TestDADDisabled(t *testing.T) { } } +func TestDADResolveLoopback(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + } + + dadConfigs := stack.DADConfigurations{ + RetransmitTimer: time.Second, + DupAddrDetectTransmits: 1, + } + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + Clock: clock, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + DADConfigs: dadConfigs, + })}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + } + if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + } + + // Address should not be considered bound to the NIC yet (DAD ongoing). + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) + } + + // DAD should not resolve after the normal resolution time since our DAD + // message was looped back - we should extend our DAD process. + dadResolutionTime := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer + clock.Advance(dadResolutionTime) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Error(err) + } + + // Make sure the address does not resolve before the extended resolution time + // has passed. + const delta = time.Nanosecond + // DAD will send extra NS probes if an NS message is looped back. + const extraTransmits = 3 + clock.Advance(dadResolutionTime*extraTransmits - delta) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Error(err) + } + + // DAD should now resolve. + clock.Advance(delta) + if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -404,6 +466,16 @@ func TestDADResolve(t *testing.T) { }, } + nonces := [][]byte{ + {1, 2, 3, 4, 5, 6}, + {7, 8, 9, 10, 11, 12}, + } + + var secureRNGBytes []byte + for _, n := range nonces { + secureRNGBytes = append(secureRNGBytes, n...) + } + for _, test := range tests { test := test @@ -419,7 +491,12 @@ func TestDADResolve(t *testing.T) { headerLength: test.linkHeaderLen, } e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + var secureRNG bytes.Reader + secureRNG.Reset(secureRNGBytes) + s := stack.New(stack.Options{ + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: stack.DADConfigurations{ @@ -553,7 +630,7 @@ func TestDADResolve(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(addr1), - checker.NDPNSOptions(nil), + checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[i])}), )) if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 1fffe9274..11ff65bf2 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -23,6 +23,7 @@ import ( "bytes" "encoding/binary" "fmt" + "io" mathrand "math/rand" "sync/atomic" "time" @@ -445,6 +446,9 @@ type Stack struct { // used when a random number is required. randomGenerator *mathrand.Rand + // secureRNG is a cryptographically secure random number generator. + secureRNG io.Reader + // sendBufferSize holds the min/default/max send buffer sizes for // endpoints other than TCP. sendBufferSize tcpip.SendBufferSizeOption @@ -528,6 +532,9 @@ type Options struct { // IPTables are the initial iptables rules. If nil, iptables will allow // all traffic. IPTables *IPTables + + // SecureRNG is a cryptographically secure random number generator. + SecureRNG io.Reader } // TransportEndpointInfo holds useful information about a transport endpoint @@ -636,6 +643,10 @@ func New(opts Options) *Stack { opts.NUDConfigs.resetInvalidFields() + if opts.SecureRNG == nil { + opts.SecureRNG = rand.Reader + } + s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), @@ -652,6 +663,7 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), + secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -2048,6 +2060,12 @@ func (s *Stack) Rand() *mathrand.Rand { return s.randomGenerator } +// SecureRNG returns the stack's cryptographically secure random number +// generator. +func (s *Stack) SecureRNG() io.Reader { + return s.secureRNG +} + func generateRandUint32() uint32 { b := make([]byte, 4) if _, err := rand.Read(b); err != nil { |