diff options
Diffstat (limited to 'pkg/tcpip')
36 files changed, 4671 insertions, 797 deletions
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index a3485b35c..f2061c778 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -38,12 +38,16 @@ go_test( size = "small", srcs = [ "checksum_test.go", + "ipv6_test.go", "ipversion_test.go", "tcp_test.go", ], deps = [ ":header", + "//pkg/rand", + "//pkg/tcpip", "//pkg/tcpip/buffer", + "@com_github_google_go-cmp//cmp:go_default_library", ], ) @@ -55,5 +59,8 @@ go_test( "ndp_test.go", ], embed = [":header"], - deps = ["//pkg/tcpip"], + deps = [ + "//pkg/tcpip", + "@com_github_google_go-cmp//cmp:go_default_library", + ], ) diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 0caa51c1e..135a60b12 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -15,6 +15,7 @@ package header import ( + "crypto/sha256" "encoding/binary" "strings" @@ -90,6 +91,23 @@ const ( // IPv6Any is the non-routable IPv6 "any" meta address. It is also // known as the unspecified address. IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + // IIDSize is the size of an interface identifier (IID), in bytes, as + // defined by RFC 4291 section 2.5.1. + IIDSize = 8 + + // IIDOffsetInIPv6Address is the offset, in bytes, from the start + // of an IPv6 address to the beginning of the interface identifier + // (IID) for auto-generated addresses. That is, all bytes before + // the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all + // bytes including and after the IIDOffsetInIPv6Address-th byte are + // for the IID. + IIDOffsetInIPv6Address = 8 + + // OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes + // for the secret key used to generate an opaque interface identifier as + // outlined by RFC 7217. + OpaqueIIDSecretKeyMinBytes = 16 ) // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the @@ -266,27 +284,43 @@ func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { return solicitedNodeMulticastPrefix + addr[len(addr)-3:] } +// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64 +// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1. +// +// buf MUST be at least 8 bytes. +func EthernetAdddressToModifiedEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) { + buf[0] = linkAddr[0] ^ 2 + buf[1] = linkAddr[1] + buf[2] = linkAddr[2] + buf[3] = 0xFF + buf[4] = 0xFE + buf[5] = linkAddr[3] + buf[6] = linkAddr[4] + buf[7] = linkAddr[5] +} + +// EthernetAddressToModifiedEUI64 computes a modified EUI-64 from a 48-bit +// Ethernet/MAC address, as per RFC 4291 section 2.5.1. +func EthernetAddressToModifiedEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte { + var buf [IIDSize]byte + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) + return buf +} + // LinkLocalAddr computes the default IPv6 link-local address from a link-layer // (MAC) address. func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { - // Convert a 48-bit MAC to an EUI-64 and then prepend the link-local - // header, FE80::. + // Convert a 48-bit MAC to a modified EUI-64 and then prepend the + // link-local header, FE80::. // // The conversion is very nearly: // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff // Note the capital A. The conversion aa->Aa involves a bit flip. - lladdrb := [16]byte{ - 0: 0xFE, - 1: 0x80, - 8: linkAddr[0] ^ 2, - 9: linkAddr[1], - 10: linkAddr[2], - 11: 0xFF, - 12: 0xFE, - 13: linkAddr[3], - 14: linkAddr[4], - 15: linkAddr[5], + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, } + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:]) return tcpip.Address(lladdrb[:]) } @@ -298,3 +332,42 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { } return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } + +// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier +// (IID) to buf as outlined by RFC 7217 and returns the extended buffer. +// +// The opaque IID is generated from the cryptographic hash of the concatenation +// of the prefix, NIC's name, DAD counter (DAD retry counter) and the secret +// key. The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes bytes and +// MUST be generated to a pseudo-random number. See RFC 4086 for randomness +// requirements for security. +// +// If buf has enough capacity for the IID (IIDSize bytes), a new underlying +// array for the buffer will not be allocated. +func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte { + // As per RFC 7217 section 5, the opaque identifier can be generated as a + // cryptographic hash of the concatenation of each of the function parameters. + // Note, we omit the optional Network_ID field. + h := sha256.New() + // h.Write never returns an error. + h.Write([]byte(prefix.ID()[:IIDOffsetInIPv6Address])) + h.Write([]byte(nicName)) + h.Write([]byte{dadCounter}) + h.Write(secretKey) + + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + return append(buf, sum[:IIDSize]...) +} + +// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with +// an opaque IID. +func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address { + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, + } + + return tcpip.Address(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey)) +} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go new file mode 100644 index 000000000..1994003ed --- /dev/null +++ b/pkg/tcpip/header/ipv6_test.go @@ -0,0 +1,208 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header_test + +import ( + "bytes" + "crypto/sha256" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + +func TestEthernetAdddressToModifiedEUI64(t *testing.T) { + expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6} + + if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" { + t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff) + } + + var buf [header.IIDSize]byte + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) + if diff := cmp.Diff(expectedIID, buf); diff != "" { + t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff) + } +} + +func TestLinkLocalAddr(t *testing.T) { + if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want { + t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) + } +} + +func TestAppendOpaqueInterfaceIdentifier(t *testing.T) { + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte + if n, err := rand.Read(secretKeyBuf[:]); err != nil { + t.Fatalf("rand.Read(_): %s", err) + } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) + } + + tests := []struct { + name string + prefix tcpip.Subnet + nicName string + dadCounter uint8 + secretKey []byte + }{ + { + name: "SecretKey of minimum size", + prefix: header.IPv6LinkLocalPrefix.Subnet(), + nicName: "eth0", + dadCounter: 0, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], + }, + { + name: "SecretKey of less than minimum size", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "eth10", + dadCounter: 1, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], + }, + { + name: "SecretKey of more than minimum size", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "eth11", + dadCounter: 2, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], + }, + { + name: "Nil SecretKey and empty nicName", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "", + dadCounter: 3, + secretKey: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + h := sha256.New() + h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address])) + h.Write([]byte(test.nicName)) + h.Write([]byte{test.dadCounter}) + if k := test.secretKey; k != nil { + h.Write(k) + } + var hashSum [sha256.Size]byte + h.Sum(hashSum[:0]) + want := hashSum[:header.IIDSize] + + // Passing a nil buffer should result in a new buffer returned with the + // IID. + if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { + t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) + } + + // Passing a buffer with sufficient capacity for the IID should populate + // the buffer provided. + var iidBuf [header.IIDSize]byte + if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { + t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) + } + if got := iidBuf[:]; !bytes.Equal(got, want) { + t.Errorf("got iidBuf = %x, want = %x", got, want) + } + }) + } +} + +func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte + if n, err := rand.Read(secretKeyBuf[:]); err != nil { + t.Fatalf("rand.Read(_): %s", err) + } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) + } + + prefix := header.IPv6LinkLocalPrefix.Subnet() + + tests := []struct { + name string + prefix tcpip.Subnet + nicName string + dadCounter uint8 + secretKey []byte + }{ + { + name: "SecretKey of minimum size", + nicName: "eth0", + dadCounter: 0, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], + }, + { + name: "SecretKey of less than minimum size", + nicName: "eth10", + dadCounter: 1, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], + }, + { + name: "SecretKey of more than minimum size", + nicName: "eth11", + dadCounter: 2, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], + }, + { + name: "Nil SecretKey and empty nicName", + nicName: "", + dadCounter: 3, + secretKey: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + addrBytes := [header.IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, + } + + want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier( + addrBytes[:header.IIDOffsetInIPv6Address], + prefix, + test.nicName, + test.dadCounter, + test.secretKey, + )) + + if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want { + t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want) + } + }) + } +} diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 1ca6199ef..06e0bace2 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -17,6 +17,7 @@ package header import ( "encoding/binary" "errors" + "math" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -85,6 +86,23 @@ const ( // within an NDPPrefixInformation. ndpPrefixInformationPrefixOffset = 14 + // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS + // Server option, as per RFC 8106 section 5.1. + NDPRecursiveDNSServerOptionType = 25 + + // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte + // Lifetime field within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerLifetimeOffset = 2 + + // ndpRecursiveDNSServerAddressesOffset is the start of the addresses + // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerAddressesOffset = 6 + + // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS + // Server option's length field value when it contains at least one + // IPv6 address. + minNDPRecursiveDNSServerLength = 3 + // lengthByteUnits is the multiplier factor for the Length field of an // NDP option. That is, the length field for NDP options is in units of // 8 octets, as per RFC 4861 section 4.6. @@ -92,13 +110,13 @@ const ( ) var ( - // NDPPrefixInformationInfiniteLifetime is a value that represents - // infinity for the Valid and Preferred Lifetime fields in a NDP Prefix - // Information option. Its value is (2^32 - 1)s = 4294967295s + // NDPInfiniteLifetime is a value that represents infinity for the + // 4-byte lifetime fields found in various NDP options. Its value is + // (2^32 - 1)s = 4294967295s. // // This is a variable instead of a constant so that tests can change // this value to a smaller value. It should only be modified by tests. - NDPPrefixInformationInfiniteLifetime = time.Second * 4294967295 + NDPInfiniteLifetime = time.Second * math.MaxUint32 ) // NDPOptionIterator is an iterator of NDPOption. @@ -118,6 +136,7 @@ var ( ErrNDPOptBufExhausted = errors.New("Buffer unexpectedly exhausted") ErrNDPOptZeroLength = errors.New("NDP option has zero-valued Length field") ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body") + ErrNDPInvalidLength = errors.New("NDP option's Length value is invalid as per relevant RFC") ) // Next returns the next element in the backing NDPOptions, or true if we are @@ -182,6 +201,22 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { } return NDPPrefixInformation(body), false, nil + + case NDPRecursiveDNSServerOptionType: + // RFC 8106 section 5.3.1 outlines that the RDNSS option + // must have a minimum length of 3 so it contains at + // least one IPv6 address. + if l < minNDPRecursiveDNSServerLength { + return nil, true, ErrNDPInvalidLength + } + + opt := NDPRecursiveDNSServer(body) + if len(opt.Addresses()) == 0 { + return nil, true, ErrNDPOptMalformedBody + } + + return opt, false, nil + default: // We do not yet recognize the option, just skip for // now. This is okay because RFC 4861 allows us to @@ -434,7 +469,7 @@ func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool { // // Note, a value of 0 implies the prefix should not be considered as on-link, // and a value of infinity/forever is represented by -// NDPPrefixInformationInfiniteLifetime. +// NDPInfiniteLifetime. func (o NDPPrefixInformation) ValidLifetime() time.Duration { // The field is the time in seconds, as per RFC 4861 section 4.6.2. return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:])) @@ -447,7 +482,7 @@ func (o NDPPrefixInformation) ValidLifetime() time.Duration { // // Note, a value of 0 implies that addresses generated from the prefix should // no longer remain preferred, and a value of infinity is represented by -// NDPPrefixInformationInfiniteLifetime. +// NDPInfiniteLifetime. // // Also note that the value of this field MUST NOT exceed the Valid Lifetime // field to avoid preferring addresses that are no longer valid, for the @@ -476,3 +511,79 @@ func (o NDPPrefixInformation) Subnet() tcpip.Subnet { } return addrWithPrefix.Subnet() } + +// NDPRecursiveDNSServer is the NDP Recursive DNS Server option, as defined by +// RFC 8106 section 5.1. +// +// To make sure that the option meets its minimum length and does not end in the +// middle of a DNS server's IPv6 address, the length of a valid +// NDPRecursiveDNSServer must meet the following constraint: +// (Length - ndpRecursiveDNSServerAddressesOffset) % IPv6AddressSize == 0 +type NDPRecursiveDNSServer []byte + +// Type returns the type of an NDP Recursive DNS Server option. +// +// Type implements NDPOption.Type. +func (NDPRecursiveDNSServer) Type() uint8 { + return NDPRecursiveDNSServerOptionType +} + +// Length implements NDPOption.Length. +func (o NDPRecursiveDNSServer) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the reserved bytes that are before the Lifetime field. + for i := 0; i < ndpRecursiveDNSServerLifetimeOffset; i++ { + b[i] = 0 + } + + return used +} + +// Lifetime returns the length of time that the DNS server addresses +// in this option may be used for name resolution. +// +// Note, a value of 0 implies the addresses should no longer be used, +// and a value of infinity/forever is represented by NDPInfiniteLifetime. +// +// Lifetime may panic if o does not have enough bytes to hold the Lifetime +// field. +func (o NDPRecursiveDNSServer) Lifetime() time.Duration { + // The field is the time in seconds, as per RFC 8106 section 5.1. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRecursiveDNSServerLifetimeOffset:])) +} + +// Addresses returns the recursive DNS server IPv6 addresses that may be +// used for name resolution. +// +// Note, some of the addresses returned MAY be link-local addresses. +// +// Addresses may panic if o does not hold valid IPv6 addresses. +func (o NDPRecursiveDNSServer) Addresses() []tcpip.Address { + l := len(o) + if l < ndpRecursiveDNSServerAddressesOffset { + return nil + } + + l -= ndpRecursiveDNSServerAddressesOffset + if l%IPv6AddressSize != 0 { + return nil + } + + buf := o[ndpRecursiveDNSServerAddressesOffset:] + var addrs []tcpip.Address + for len(buf) > 0 { + addr := tcpip.Address(buf[:IPv6AddressSize]) + if !IsV6UnicastAddress(addr) { + return nil + } + addrs = append(addrs, addr) + buf = buf[IPv6AddressSize:] + } + return addrs +} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index ad6daafcd..2c439d70c 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -369,6 +370,175 @@ func TestNDPPrefixInformationOption(t *testing.T) { } } +func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { + b := []byte{ + 9, 8, + 1, 2, 4, 8, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + expected := []byte{ + 25, 3, 0, 0, + 1, 2, 4, 8, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPRecursiveDNSServer(b), + } + if got, want := opts.Serialize(serializer), len(expected); got != want { + t.Errorf("got Serialize = %d, want = %d", got, want) + } + if !bytes.Equal(targetBuf, expected) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPRecursiveDNSServerOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) + } + + opt, ok := next.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) + } + if got := opt.Type(); got != 25 { + t.Errorf("got Type = %d, want = 31", got) + } + if got := opt.Length(); got != 22 { + t.Errorf("got Length = %d, want = 22", got) + } + if got, want := opt.Lifetime(), 16909320*time.Second; got != want { + t.Errorf("got Lifetime = %s, want = %s", got, want) + } + want := []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + } + if got := opt.Addresses(); !cmp.Equal(got, want) { + t.Errorf("got Addresses = %v, want = %v", got, want) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + +func TestNDPRecursiveDNSServerOption(t *testing.T) { + tests := []struct { + name string + buf []byte + lifetime time.Duration + addrs []tcpip.Address + }{ + { + "Valid1Addr", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + }, + }, + { + "Valid2Addr", + []byte{ + 25, 5, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", + }, + }, + { + "Valid3Addr", + []byte{ + 25, 7, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11", + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + // Iterator should get our option. + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPRecursiveDNSServerOptionType { + t.Fatalf("got Type %= %d, want = %d", got, NDPRecursiveDNSServerOptionType) + } + + opt, ok := next.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) + } + if got := opt.Lifetime(); got != test.lifetime { + t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) + } + if got := opt.Addresses(); !cmp.Equal(got, test.addrs) { + t.Errorf("got Addresses = %v, want = %v", got, test.addrs) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } + }) + } +} + // TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions // the iterator was returned for is malformed. func TestNDPOptionsIterCheck(t *testing.T) { @@ -473,6 +643,51 @@ func TestNDPOptionsIterCheck(t *testing.T) { }, nil, }, + { + "InvalidRecursiveDNSServerCutsOffAddress", + []byte{ + 25, 4, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 1, 2, 3, 4, 5, 6, 7, + }, + ErrNDPOptMalformedBody, + }, + { + "InvalidRecursiveDNSServerInvalidLengthField", + []byte{ + 25, 2, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, + }, + ErrNDPInvalidLength, + }, + { + "RecursiveDNSServerTooSmall", + []byte{ + 25, 1, 0, 0, + 0, 0, 0, + }, + ErrNDPOptBufExhausted, + }, + { + "RecursiveDNSServerMulticast", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }, + ErrNDPOptMalformedBody, + }, + { + "RecursiveDNSServerUnspecified", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, + ErrNDPOptMalformedBody, + }, } for _, test := range tests { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index da8482509..42cacb8a6 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -79,16 +79,16 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4144a7837..f1bc33adf 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -239,7 +239,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -480,7 +480,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e645cf62c..4ee3d5b45 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -238,11 +238,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) @@ -256,7 +256,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw loopedR.Release() } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { @@ -270,11 +270,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { - if loop&stack.PacketLoop != 0 { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { + if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return len(pkts), nil } @@ -289,7 +289,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. ip := header.IPv4(pkt.Data.First()) @@ -324,10 +324,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLo ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { e.HandlePacket(r, pkt.Clone()) } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index dd31f0fb7..58c3c79b9 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -112,11 +112,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) @@ -130,7 +130,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw loopedR.Release() } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } @@ -139,11 +139,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { - if loop&stack.PacketLoop != 0 { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { + if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return len(pkts), nil } @@ -161,8 +161,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { - // TODO(b/119580726): Support IPv6 header-included packets. +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { + // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 4839f0a65..e156b01f6 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 30cea8996..6c5e19e8f 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -41,6 +41,30 @@ type portDescriptor struct { port uint16 } +// Flags represents the type of port reservation. +// +// +stateify savable +type Flags struct { + // MostRecent represents UDP SO_REUSEADDR. + MostRecent bool + + // LoadBalanced indicates SO_REUSEPORT. + // + // LoadBalanced takes precidence over MostRecent. + LoadBalanced bool +} + +func (f Flags) bits() reuseFlag { + var rf reuseFlag + if f.MostRecent { + rf |= mostRecentFlag + } + if f.LoadBalanced { + rf |= loadBalancedFlag + } + return rf +} + // PortManager manages allocating, reserving and releasing ports. type PortManager struct { mu sync.RWMutex @@ -54,9 +78,59 @@ type PortManager struct { hint uint32 } +type reuseFlag int + +const ( + mostRecentFlag reuseFlag = 1 << iota + loadBalancedFlag + nextFlag + + flagMask = nextFlag - 1 +) + type portNode struct { - reuse bool - refs int + // refs stores the count for each possible flag combination. + refs [nextFlag]int +} + +func (p portNode) totalRefs() int { + var total int + for _, r := range p.refs { + total += r + } + return total +} + +// flagRefs returns the number of references with all specified flags. +func (p portNode) flagRefs(flags reuseFlag) int { + var total int + for i, r := range p.refs { + if reuseFlag(i)&flags == flags { + total += r + } + } + return total +} + +// allRefsHave returns if all references have all specified flags. +func (p portNode) allRefsHave(flags reuseFlag) bool { + for i, r := range p.refs { + if reuseFlag(i)&flags == flags && r > 0 { + return false + } + } + return true +} + +// intersectionRefs returns the set of flags shared by all references. +func (p portNode) intersectionRefs() reuseFlag { + intersection := flagMask + for i, r := range p.refs { + if r > 0 { + intersection &= reuseFlag(i) + } + } + return intersection } // deviceNode is never empty. When it has no elements, it is removed from the @@ -66,30 +140,44 @@ type deviceNode map[tcpip.NICID]portNode // isAvailable checks whether binding is possible by device. If not binding to a // device, check against all portNodes. If binding to a specific device, check // against the unspecified device and the provided device. -func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool { +// +// If either of the port reuse flags is enabled on any of the nodes, all nodes +// sharing a port must share at least one reuse flag. This matches Linux's +// behavior. +func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { + flagBits := flags.bits() if bindToDevice == 0 { // Trying to binding all devices. - if !reuse { + if flagBits == 0 { // Can't bind because the (addr,port) is already bound. return false } + intersection := flagMask for _, p := range d { - if !p.reuse { - // Can't bind because the (addr,port) was previously bound without reuse. + i := p.intersectionRefs() + intersection &= i + if intersection&flagBits == 0 { + // Can't bind because the (addr,port) was + // previously bound without reuse. return false } } return true } + intersection := flagMask + if p, ok := d[0]; ok { - if !reuse || !p.reuse { + intersection = p.intersectionRefs() + if intersection&flagBits == 0 { return false } } if p, ok := d[bindToDevice]; ok { - if !reuse || !p.reuse { + i := p.intersectionRefs() + intersection &= i + if intersection&flagBits == 0 { return false } } @@ -103,12 +191,12 @@ type bindAddresses map[tcpip.Address]deviceNode // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. -func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool { +func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID) bool { if addr == anyIPAddress { // If binding to the "any" address then check that there are no conflicts // with all addresses. for _, d := range b { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice) { return false } } @@ -117,14 +205,14 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice // Check that there is no conflict with the "any" address. if d, ok := b[anyIPAddress]; ok { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice) { return false } } // Check that this is no conflict with the provided address. if d, ok := b[addr]; ok { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice) { return false } } @@ -190,17 +278,17 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) + return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, reuse, bindToDevice) { + if !addrs.isAvailable(addr, flags, bindToDevice) { return false } } @@ -212,14 +300,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) { + if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice) { return 0, tcpip.ErrPortInUse } return port, nil @@ -227,15 +315,16 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil + return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) { return false } + flagBits := flags.bits() // Reserve port on all network protocols. for _, network := range networks { @@ -250,12 +339,9 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber d = make(deviceNode) m[addr] = d } - if n, ok := d[bindToDevice]; ok { - n.refs++ - d[bindToDevice] = n - } else { - d[bindToDevice] = portNode{reuse: reuse, refs: 1} - } + n := d[bindToDevice] + n.refs[flagBits]++ + d[bindToDevice] = n } return true @@ -263,10 +349,12 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() + flagBits := flags.bits() + for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { @@ -278,9 +366,9 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp if !ok { continue } - n.refs-- + n.refs[flagBits]-- d[bindToDevice] = n - if n.refs == 0 { + if n.refs == [nextFlag]int{} { delete(d, bindToDevice) } if len(d) == 0 { diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 19f4833fc..d6969d050 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -33,7 +33,7 @@ type portReserveTestAction struct { port uint16 ip tcpip.Address want *tcpip.Error - reuse bool + flags Flags release bool device tcpip.NICID } @@ -50,7 +50,7 @@ func TestPortReservation(t *testing.T) { {port: 80, ip: fakeIPAddress1, want: nil}, /* N.B. Order of tests matters! */ {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse}, - {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, }, }, { @@ -61,7 +61,7 @@ func TestPortReservation(t *testing.T) { /* release fakeIPAddress, but anyIPAddress is still inuse */ {port: 22, ip: fakeIPAddress, release: true}, {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ {port: 22, ip: anyIPAddress, want: nil, release: true}, {port: 22, ip: fakeIPAddress, want: nil}, @@ -71,36 +71,36 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 00, ip: fakeIPAddress, want: nil}, {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "bind to ip with reuseport", actions: []portReserveTestAction{ - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 25, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, reuse: true, want: nil}, + {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "bind to inaddr any with reuseport", actions: []portReserveTestAction{ - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, release: true, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil}, - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil}, }, }, { tname: "bind twice with device fails", @@ -125,88 +125,152 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "bind with device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 789, want: nil}, {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { - tname: "bind with reuse", + tname: "bind with reuseport", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { - tname: "binding with reuse and device", + tname: "binding with reuseport and device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, }, }, { - tname: "mixing reuse and not reuse by binding to device", + tname: "mixing reuseport and not reuseport by binding to device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 999, want: nil}, }, }, { - tname: "can't bind to 0 after mixing reuse and not reuse", + tname: "can't bind to 0 after mixing reuseport and not reuseport", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "bind and release", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, // Release the bind to device 0 and try again. - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil}, }, }, { - tname: "bind twice with reuse once", + tname: "bind twice with reuseport once", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "release an unreserved device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil}, // The below don't exist. - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true}, - {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true}, + {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, // Release all. - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true}, + }, + }, { + tname: "bind with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil}, + }, + }, { + tname: "bind twice with reuseaddr once", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr and reuseport, and then reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport, and then reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport twice, and then reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr and reuseport twice, and then reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr, and then reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseport, and then reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, }, }, } { @@ -216,12 +280,12 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index a57752a7c..d7496fde6 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_binary( name = "tun_tcp_connect", srcs = ["main.go"], + visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD index dad8ef399..875561566 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_binary( name = "tun_tcp_echo", srcs = ["main.go"], + visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", "//pkg/tcpip/link/fdbased", diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 69077669a..b8f9517d0 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -59,6 +59,7 @@ go_test( ], deps = [ ":stack", + "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index cfdd0496e..4722ec9ce 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -58,6 +58,14 @@ const ( // Default = true. defaultDiscoverOnLinkPrefixes = true + // defaultAutoGenGlobalAddresses is the default configuration for + // whether or not to generate global IPv6 addresses in response to + // receiving a new Prefix Information option with its Autonomous + // Address AutoConfiguration flag set, as a host. + // + // Default = true. + defaultAutoGenGlobalAddresses = true + // minimumRetransmitTimer is the minimum amount of time to wait between // sending NDP Neighbor solicitation messages. Note, RFC 4861 does // not impose a minimum Retransmit Timer, but we do here to make sure @@ -87,6 +95,24 @@ const ( // // Max = 10. MaxDiscoveredOnLinkPrefixes = 10 + + // validPrefixLenForAutoGen is the expected prefix length that an + // address can be generated for. Must be 64 bits as the interface + // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so + // 128 - 64 = 64. + validPrefixLenForAutoGen = 64 +) + +var ( + // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid + // Lifetime to update the valid lifetime of a generated address by + // SLAAC. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // Min = 2hrs. + MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour ) // NDPDispatcher is the interface integrators of netstack must implement to @@ -105,40 +131,70 @@ type NDPDispatcher interface { OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) // OnDefaultRouterDiscovered will be called when a new default router is - // discovered. Implementations must return true along with a new valid - // route table if the newly discovered router should be remembered. If - // an implementation returns false, the second return value will be - // ignored. + // discovered. Implementations must return true if the newly discovered + // router should be remembered. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) + OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool // OnDefaultRouterInvalidated will be called when a discovered default - // router is invalidated. Implementers must return a new valid route - // table. + // router that was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route + OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is - // discovered. Implementations must return true along with a new valid - // route table if the newly discovered on-link prefix should be - // remembered. If an implementation returns false, the second return - // value will be ignored. + // discovered. Implementations must return true if the newly discovered + // on-link prefix should be remembered. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) (bool, []tcpip.Route) + OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool // OnOnLinkPrefixInvalidated will be called when a discovered on-link - // prefix is invalidated. Implementers must return a new valid route - // table. + // prefix that was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route + OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) + + // OnAutoGenAddress will be called when a new prefix with its + // autonomous address-configuration flag set has been received and SLAAC + // has been performed. Implementations may prevent the stack from + // assigning the address to the NIC by returning false. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + + // OnAutoGenAddressDeprecated will be called when an auto-generated + // address (as part of SLAAC) has been deprecated, but is still + // considered valid. Note, if an address is invalidated at the same + // time it is deprecated, the deprecation event MAY be omitted. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) + + // OnAutoGenAddressInvalidated will be called when an auto-generated + // address (as part of SLAAC) has been invalidated. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) + + // OnRecursiveDNSServerOption will be called when an NDP option with + // recursive DNS servers has been received. Note, addrs may contain + // link-local addresses. + // + // It is up to the caller to use the DNS Servers only for their valid + // lifetime. OnRecursiveDNSServerOption may be called for new or + // already known DNS servers. If called with known DNS servers, their + // valid lifetimes must be refreshed to lifetime (it may be increased, + // decreased, or completely invalidated when lifetime = 0). + OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) } // NDPConfigurations is the NDP configurations for the netstack. @@ -168,6 +224,17 @@ type NDPConfigurations struct { // will be discovered from Router Advertisements' Prefix Information // option. This configuration is ignored if HandleRAs is false. DiscoverOnLinkPrefixes bool + + // AutoGenGlobalAddresses determines whether or not global IPv6 + // addresses will be generated for a NIC in response to receiving a new + // Prefix Information option with its Autonomous Address + // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC). + // + // Note, if an address was already generated for some unique prefix, as + // part of SLAAC, this option does not affect whether or not the + // lifetime(s) of the generated address changes; this option only + // affects the generation of new addresses as part of SLAAC. + AutoGenGlobalAddresses bool } // DefaultNDPConfigurations returns an NDPConfigurations populated with @@ -179,6 +246,7 @@ func DefaultNDPConfigurations() NDPConfigurations { HandleRAs: defaultHandleRAs, DiscoverDefaultRouters: defaultDiscoverDefaultRouters, DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, + AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, } } @@ -210,6 +278,9 @@ type ndpState struct { // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState + + // The addresses generated by SLAAC. + autoGenAddresses map[tcpip.Address]autoGenAddressState } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -270,6 +341,43 @@ type onLinkPrefixState struct { doNotInvalidate *bool } +// autoGenAddressState holds data associated with an address generated via +// SLAAC. +type autoGenAddressState struct { + // A reference to the referencedNetworkEndpoint that this autoGenAddressState + // is holding state for. + ref *referencedNetworkEndpoint + + deprecationTimer *time.Timer + + // Used to signal the timer not to deprecate the SLAAC address in a race + // condition. Used for the same reason as doNotInvalidate, but for deprecating + // an address. + doNotDeprecate *bool + + invalidationTimer *time.Timer + + // Used to signal the timer not to invalidate the SLAAC address (A) in + // a race condition (T1 is a goroutine that handles a PI for A and T2 + // is the goroutine that handles A's invalidation timer firing): + // T1: Receive a new PI for A + // T1: Obtain the NIC's lock before processing the PI + // T2: A's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends A's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates A immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using doNotInvalidate to not invalidate A, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool + + // Nonzero only when the address is not valid forever (invalidationTimer + // is not nil). + validUntil time.Time +} + // startDuplicateAddressDetection performs Duplicate Address Detection. // // This function must only be called by IPv6 addresses that are currently @@ -534,19 +642,21 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // we do not check the iterator for errors on calls to Next. it, _ := ra.Options().Iter(false) for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { - switch opt.Type() { - case header.NDPPrefixInformationType: - if !ndp.configs.DiscoverOnLinkPrefixes { + switch opt := opt.(type) { + case header.NDPRecursiveDNSServer: + if ndp.nic.stack.ndpDisp == nil { continue } - pi := opt.(header.NDPPrefixInformation) + ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), opt.Addresses(), opt.Lifetime()) - prefix := pi.Subnet() + case header.NDPPrefixInformation: + prefix := opt.Subnet() // Is the prefix a link-local? if header.IsV6LinkLocalAddress(prefix.ID()) { - // ...Yes, skip as per RFC 4861 section 6.3.4. + // ...Yes, skip as per RFC 4861 section 6.3.4, + // and RFC 4862 section 5.5.3.b (for SLAAC). continue } @@ -557,82 +667,13 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { continue } - if !pi.OnLinkFlag() { - // Not on-link so don't "discover" it as an - // on-link prefix. - continue - } - - prefixState, ok := ndp.onLinkPrefixes[prefix] - vl := pi.ValidLifetime() - switch { - case !ok && vl == 0: - // Don't know about this prefix but has a zero - // valid lifetime, so just ignore. - continue - - case !ok && vl != 0: - // This is a new on-link prefix we are - // discovering. - // - // Only remember it if we currently know about - // less than MaxDiscoveredOnLinkPrefixes on-link - // prefixes. - if len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes { - ndp.rememberOnLinkPrefix(prefix, vl) - } - continue - - case ok && vl == 0: - // We know about the on-link prefix, but it is - // no longer to be considered on-link, so - // invalidate it. - ndp.invalidateOnLinkPrefix(prefix) - continue + if opt.OnLinkFlag() { + ndp.handleOnLinkPrefixInformation(opt) } - // This is an already discovered on-link prefix with a - // new non-zero valid lifetime. - // Update the invalidation timer. - timer := prefixState.invalidationTimer - - if timer == nil && vl >= header.NDPPrefixInformationInfiniteLifetime { - // Had infinite valid lifetime before and - // continues to have an invalid lifetime. Do - // nothing further. - continue + if opt.AutonomousAddressConfigurationFlag() { + ndp.handleAutonomousPrefixInformation(opt) } - - if timer != nil && !timer.Stop() { - // If we reach this point, then we know the - // timer already fired after we took the NIC - // lock. Inform the timer to not invalidate - // the prefix once it obtains the lock as we - // just got a new PI that refeshes its lifetime - // to a non-zero value. See - // onLinkPrefixState.doNotInvalidate for more - // details. - *prefixState.doNotInvalidate = true - } - - if vl >= header.NDPPrefixInformationInfiniteLifetime { - // Prefix is now valid forever so we don't need - // an invalidation timer. - prefixState.invalidationTimer = nil - ndp.onLinkPrefixes[prefix] = prefixState - continue - } - - if timer != nil { - // We already have a timer so just reset it to - // expire after the new valid lifetime. - timer.Reset(vl) - continue - } - - // We do not have a timer so just create a new one. - prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate) - ndp.onLinkPrefixes[prefix] = prefixState } // TODO(b/141556115): Do (MTU) Parameter Discovery. @@ -641,7 +682,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // invalidateDefaultRouter invalidates a discovered default router. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { rtr, ok := ndp.defaultRouters[ip] @@ -659,8 +700,8 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. - if ndp.nic.stack.ndpDisp != nil { - ndp.nic.stack.routeTable = ndp.nic.stack.ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) } } @@ -669,15 +710,15 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { // // The router identified by ip MUST NOT already be known by the NIC. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { - if ndp.nic.stack.ndpDisp == nil { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { return } // Inform the integrator when we discovered a default router. - remember, routeTable := ndp.nic.stack.ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) - if !remember { + if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) { // Informed by the integrator to not remember the router, do // nothing further. return @@ -704,8 +745,6 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { }), doNotInvalidate: &doNotInvalidate, } - - ndp.nic.stack.routeTable = routeTable } // rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 @@ -713,15 +752,15 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { // // The prefix identified by prefix MUST NOT already be known. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { - if ndp.nic.stack.ndpDisp == nil { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { return } // Inform the integrator when we discovered an on-link prefix. - remember, routeTable := ndp.nic.stack.ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) - if !remember { + if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) { // Informed by the integrator to not remember the prefix, do // nothing further. return @@ -734,7 +773,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) var timer *time.Timer // Only create a timer if the lifetime is not infinite. - if l < header.NDPPrefixInformationInfiniteLifetime { + if l < header.NDPInfiniteLifetime { timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate) } @@ -742,13 +781,11 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) invalidationTimer: timer, doNotInvalidate: &doNotInvalidate, } - - ndp.nic.stack.routeTable = routeTable } // invalidateOnLinkPrefix invalidates a discovered on-link prefix. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { s, ok := ndp.onLinkPrefixes[prefix] @@ -769,8 +806,8 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. - if ndp.nic.stack.ndpDisp != nil { - ndp.nic.stack.routeTable = ndp.nic.stack.ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) } } @@ -795,3 +832,466 @@ func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Dur ndp.invalidateOnLinkPrefix(prefix) }) } + +// handleOnLinkPrefixInformation handles a Prefix Information option with +// its on-link flag set, as per RFC 4861 section 6.3.4. +// +// handleOnLinkPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the on-link flag is set. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) { + prefix := pi.Subnet() + prefixState, ok := ndp.onLinkPrefixes[prefix] + vl := pi.ValidLifetime() + + if !ok && vl == 0 { + // Don't know about this prefix but it has a zero valid + // lifetime, so just ignore. + return + } + + if !ok && vl != 0 { + // This is a new on-link prefix we are discovering + // + // Only remember it if we currently know about less than + // MaxDiscoveredOnLinkPrefixes on-link prefixes. + if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes { + ndp.rememberOnLinkPrefix(prefix, vl) + } + return + } + + if ok && vl == 0 { + // We know about the on-link prefix, but it is + // no longer to be considered on-link, so + // invalidate it. + ndp.invalidateOnLinkPrefix(prefix) + return + } + + // This is an already discovered on-link prefix with a + // new non-zero valid lifetime. + // Update the invalidation timer. + timer := prefixState.invalidationTimer + + if timer == nil && vl >= header.NDPInfiniteLifetime { + // Had infinite valid lifetime before and + // continues to have an invalid lifetime. Do + // nothing further. + return + } + + if timer != nil && !timer.Stop() { + // If we reach this point, then we know the timer alread fired + // after we took the NIC lock. Inform the timer to not + // invalidate the prefix once it obtains the lock as we just + // got a new PI that refreshes its lifetime to a non-zero value. + // See onLinkPrefixState.doNotInvalidate for more details. + *prefixState.doNotInvalidate = true + } + + if vl >= header.NDPInfiniteLifetime { + // Prefix is now valid forever so we don't need + // an invalidation timer. + prefixState.invalidationTimer = nil + ndp.onLinkPrefixes[prefix] = prefixState + return + } + + if timer != nil { + // We already have a timer so just reset it to + // expire after the new valid lifetime. + timer.Reset(vl) + return + } + + // We do not have a timer so just create a new one. + prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate) + ndp.onLinkPrefixes[prefix] = prefixState +} + +// handleAutonomousPrefixInformation handles a Prefix Information option with +// its autonomous flag set, as per RFC 4862 section 5.5.3. +// +// handleAutonomousPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the autonomous flag is set. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { + vl := pi.ValidLifetime() + pl := pi.PreferredLifetime() + + // If the preferred lifetime is greater than the valid lifetime, + // silently ignore the Prefix Information option, as per RFC 4862 + // section 5.5.3.c. + if pl > vl { + return + } + + prefix := pi.Subnet() + + // Check if we already have an auto-generated address for prefix. + for addr, addrState := range ndp.autoGenAddresses { + refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: addrState.ref.ep.PrefixLen()} + if refAddrWithPrefix.Subnet() != prefix { + continue + } + + // At this point, we know we are refreshing a SLAAC generated IPv6 address + // with the prefix prefix. Do the work as outlined by RFC 4862 section + // 5.5.3.e. + ndp.refreshAutoGenAddressLifetimes(addr, pl, vl) + return + } + + // We do not already have an address within the prefix, prefix. Do the + // work as outlined by RFC 4862 section 5.5.3.d if n is configured + // to auto-generated global addresses by SLAAC. + ndp.newAutoGenAddress(prefix, pl, vl) +} + +// newAutoGenAddress generates a new SLAAC address with the provided lifetimes +// for prefix. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration) { + // Are we configured to auto-generate new global addresses? + if !ndp.configs.AutoGenGlobalAddresses { + return + } + + // If we do not already have an address for this prefix and the valid + // lifetime is 0, no need to do anything further, as per RFC 4862 + // section 5.5.3.d. + if vl == 0 { + return + } + + // Make sure the prefix is valid (as far as its length is concerned) to + // generate a valid IPv6 address from an interface identifier (IID), as + // per RFC 4862 sectiion 5.5.3.d. + if prefix.Prefix() != validPrefixLenForAutoGen { + return + } + + addrBytes := []byte(prefix.ID()) + if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), 0 /* dadCounter */, oIID.SecretKey) + } else { + // Only attempt to generate an interface-specific IID if we have a valid + // link address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + linkAddr := ndp.nic.linkEP.LinkAddress() + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return + } + + // Generate an address within prefix from the modified EUI-64 of ndp's NIC's + // Ethernet MAC address. + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + } + addr := tcpip.Address(addrBytes) + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + } + + // If the nic already has this address, do nothing further. + if ndp.nic.hasPermanentAddrLocked(addr) { + return + } + + // Inform the integrator that we have a new SLAAC address. + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return + } + if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) { + // Informed by the integrator not to add the address. + return + } + + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + } + // If the preferred lifetime is zero, then the address should be considered + // deprecated. + deprecated := pl == 0 + ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) + if err != nil { + log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err) + } + + // Setup the timers to deprecate and invalidate this newly generated address. + + var doNotDeprecate bool + var pTimer *time.Timer + if !deprecated && pl < header.NDPInfiniteLifetime { + pTimer = ndp.autoGenAddrDeprecationTimer(addr, pl, &doNotDeprecate) + } + + var doNotInvalidate bool + var vTimer *time.Timer + if vl < header.NDPInfiniteLifetime { + vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate) + } + + ndp.autoGenAddresses[addr] = autoGenAddressState{ + ref: ref, + deprecationTimer: pTimer, + doNotDeprecate: &doNotDeprecate, + invalidationTimer: vTimer, + doNotInvalidate: &doNotInvalidate, + validUntil: time.Now().Add(vl), + } +} + +// refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated +// address addr. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl time.Duration) { + addrState, ok := ndp.autoGenAddresses[addr] + if !ok { + log.Fatalf("ndp: SLAAC state not found to refresh lifetimes for %s", addr) + } + defer func() { ndp.autoGenAddresses[addr] = addrState }() + + // If the preferred lifetime is zero, then the address should be considered + // deprecated. + deprecated := pl == 0 + wasDeprecated := addrState.ref.deprecated + addrState.ref.deprecated = deprecated + + // Only send the deprecation event if the deprecated status for addr just + // changed from non-deprecated to deprecated. + if !wasDeprecated && deprecated { + ndp.notifyAutoGenAddressDeprecated(addr) + } + + // If addr was preferred for some finite lifetime before, stop the deprecation + // timer so it can be reset. + if t := addrState.deprecationTimer; t != nil && !t.Stop() { + *addrState.doNotDeprecate = true + } + + // Reset the deprecation timer. + if pl >= header.NDPInfiniteLifetime || deprecated { + // If addr is preferred forever or it has been deprecated already, there is + // no need for a deprecation timer. + addrState.deprecationTimer = nil + } else if addrState.deprecationTimer == nil { + // addr is now preferred for a finite lifetime. + addrState.deprecationTimer = ndp.autoGenAddrDeprecationTimer(addr, pl, addrState.doNotDeprecate) + } else { + // addr continues to be preferred for a finite lifetime. + addrState.deprecationTimer.Reset(pl) + } + + // As per RFC 4862 section 5.5.3.e, the valid lifetime of the address + // + // + // 1) If the received Valid Lifetime is greater than 2 hours or greater than + // RemainingLifetime, set the valid lifetime of the address to the + // advertised Valid Lifetime. + // + // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the + // advertised Valid Lifetime. + // + // 3) Otherwise, reset the valid lifetime of the address to 2 hours. + + // Handle the infinite valid lifetime separately as we do not keep a timer in + // this case. + if vl >= header.NDPInfiniteLifetime { + if addrState.invalidationTimer != nil { + // Valid lifetime was finite before, but now it is valid forever. + if !addrState.invalidationTimer.Stop() { + *addrState.doNotInvalidate = true + } + addrState.invalidationTimer = nil + addrState.validUntil = time.Time{} + } + + return + } + + var effectiveVl time.Duration + var rl time.Duration + + // If the address was originally set to be valid forever, assume the remaining + // time to be the maximum possible value. + if addrState.invalidationTimer == nil { + rl = header.NDPInfiniteLifetime + } else { + rl = time.Until(addrState.validUntil) + } + + if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { + effectiveVl = vl + } else if rl <= MinPrefixInformationValidLifetimeForUpdate { + return + } else { + effectiveVl = MinPrefixInformationValidLifetimeForUpdate + } + + if addrState.invalidationTimer == nil { + addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate) + } else { + if !addrState.invalidationTimer.Stop() { + *addrState.doNotInvalidate = true + } + addrState.invalidationTimer.Reset(effectiveVl) + } + + addrState.validUntil = time.Now().Add(effectiveVl) +} + +// notifyAutoGenAddressDeprecated notifies the stack's NDP dispatcher that addr +// has been deprecated. +func (ndp *ndpState) notifyAutoGenAddressDeprecated(addr tcpip.Address) { + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + }) + } +} + +// invalidateAutoGenAddress invalidates an auto-generated address. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) { + if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) { + return + } + + ndp.nic.removePermanentAddressLocked(addr) +} + +// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated +// address's resources from ndp. If the stack has an NDP dispatcher, it will +// be notified that addr has been invalidated. +// +// Returns true if ndp had resources for addr to cleanup. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool { + state, ok := ndp.autoGenAddresses[addr] + + if !ok { + return false + } + + if state.deprecationTimer != nil { + state.deprecationTimer.Stop() + state.deprecationTimer = nil + *state.doNotDeprecate = true + } + + state.doNotDeprecate = nil + + if state.invalidationTimer != nil { + state.invalidationTimer.Stop() + state.invalidationTimer = nil + *state.doNotInvalidate = true + } + + state.doNotInvalidate = nil + + delete(ndp.autoGenAddresses, addr) + + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + }) + } + + return true +} + +// autoGenAddrDeprecationTimer returns a new deprecation timer for an +// auto-generated address that fires after pl. +// +// doNotDeprecate is used to inform the timer when it fires at the same time +// that an auto-generated address's preferred lifetime gets refreshed. See +// autoGenAddrState.doNotDeprecate for more details. +func (ndp *ndpState) autoGenAddrDeprecationTimer(addr tcpip.Address, pl time.Duration, doNotDeprecate *bool) *time.Timer { + return time.AfterFunc(pl, func() { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if *doNotDeprecate { + *doNotDeprecate = false + return + } + + addrState, ok := ndp.autoGenAddresses[addr] + if !ok { + log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr) + } + addrState.ref.deprecated = true + ndp.notifyAutoGenAddressDeprecated(addr) + addrState.deprecationTimer = nil + ndp.autoGenAddresses[addr] = addrState + }) +} + +// autoGenAddrInvalidationTimer returns a new invalidation timer for an +// auto-generated address that fires after vl. +// +// doNotInvalidate is used to inform the timer when it fires at the same time +// that an auto-generated address's valid lifetime gets refreshed. See +// autoGenAddrState.doNotInvalidate for more details. +func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer { + return time.AfterFunc(vl, func() { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if *doNotInvalidate { + *doNotInvalidate = false + return + } + + ndp.invalidateAutoGenAddress(addr) + }) +} + +// cleanupHostOnlyState cleans up any state that is only useful for hosts. +// +// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a +// host to a router. This function will invalidate all discovered on-link +// prefixes, discovered routers, and auto-generated addresses as routers do not +// normally process Router Advertisements to discover default routers and +// on-link prefixes, and auto-generate addresses via SLAAC. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupHostOnlyState() { + for addr, _ := range ndp.autoGenAddresses { + ndp.invalidateAutoGenAddress(addr) + } + + if got := len(ndp.autoGenAddresses); got != 0 { + log.Fatalf("ndp: still have auto-generated addresses after cleaning up, found = %d", got) + } + + for prefix, _ := range ndp.onLinkPrefixes { + ndp.invalidateOnLinkPrefix(prefix) + } + + if got := len(ndp.onLinkPrefixes); got != 0 { + log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up, found = %d", got) + } + + for router, _ := range ndp.defaultRouters { + ndp.invalidateDefaultRouter(router) + } + + if got := len(ndp.defaultRouters); got != 0 { + log.Fatalf("ndp: still have discovered default routers after cleaning up, found = %d", got) + } +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 5b901f947..8d89859ba 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -21,6 +21,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -29,6 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) const ( @@ -38,15 +41,47 @@ const ( linkAddr1 = "\x02\x02\x03\x04\x05\x06" linkAddr2 = "\x02\x02\x03\x04\x05\x07" linkAddr3 = "\x02\x02\x03\x04\x05\x08" - defaultTimeout = 250 * time.Millisecond + defaultTimeout = 100 * time.Millisecond ) var ( llAddr1 = header.LinkLocalAddr(linkAddr1) llAddr2 = header.LinkLocalAddr(linkAddr2) llAddr3 = header.LinkLocalAddr(linkAddr3) + dstAddr = tcpip.FullAddress{ + Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + Port: 25, + } ) +func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix { + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return tcpip.AddressWithPrefix{} + } + + addrBytes := []byte(subnet.ID()) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: 64, + } +} + +// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent +// tcpip.Subnet, and an address where the lower half of the address is composed +// of the EUI-64 of linkAddr if it is a valid unicast ethernet address. +func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) { + prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0} + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address(prefixBytes), + PrefixLen: 64, + } + + subnet := prefix.Subnet() + + return prefix, subnet, addrForSubnet(subnet, linkAddr) +} + // TestDADDisabled tests that an address successfully resolves immediately // when DAD is not enabled (the default for an empty stack.Options). func TestDADDisabled(t *testing.T) { @@ -54,7 +89,7 @@ func TestDADDisabled(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -103,6 +138,30 @@ type ndpPrefixEvent struct { discovered bool } +type ndpAutoGenAddrEventType int + +const ( + newAddr ndpAutoGenAddrEventType = iota + deprecatedAddr + invalidatedAddr +) + +type ndpAutoGenAddrEvent struct { + nicID tcpip.NICID + addr tcpip.AddressWithPrefix + eventType ndpAutoGenAddrEventType +} + +type ndpRDNSS struct { + addrs []tcpip.Address + lifetime time.Duration +} + +type ndpRDNSSEvent struct { + nicID tcpip.NICID + rdnss ndpRDNSS +} + var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP @@ -113,7 +172,8 @@ type ndpDispatcher struct { rememberRouter bool prefixC chan ndpPrefixEvent rememberPrefix bool - routeTable []tcpip.Route + autoGenAddrC chan ndpAutoGenAddrEvent + rdnssC chan ndpRDNSSEvent } // Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. @@ -129,101 +189,95 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, add } // Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. -func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) { - if n.routerC != nil { - n.routerC <- ndpRouterEvent{ +func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ nicID, addr, true, } } - if !n.rememberRouter { - return false, nil - } - - rt := append([]tcpip.Route(nil), n.routeTable...) - rt = append(rt, tcpip.Route{ - Destination: header.IPv6EmptySubnet, - Gateway: addr, - NIC: nicID, - }) - n.routeTable = rt - return true, rt + return n.rememberRouter } // Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. -func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route { - if n.routerC != nil { - n.routerC <- ndpRouterEvent{ +func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ nicID, addr, false, } } - - var rt []tcpip.Route - exclude := tcpip.Route{ - Destination: header.IPv6EmptySubnet, - Gateway: addr, - NIC: nicID, - } - - for _, r := range n.routeTable { - if r != exclude { - rt = append(rt, r) - } - } - n.routeTable = rt - return rt } // Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) (bool, []tcpip.Route) { - if n.prefixC != nil { - n.prefixC <- ndpPrefixEvent{ +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ nicID, prefix, true, } } - if !n.rememberPrefix { - return false, nil - } - - rt := append([]tcpip.Route(nil), n.routeTable...) - rt = append(rt, tcpip.Route{ - Destination: prefix, - NIC: nicID, - }) - n.routeTable = rt - return true, rt + return n.rememberPrefix } // Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. -func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route { - if n.prefixC != nil { - n.prefixC <- ndpPrefixEvent{ +func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ nicID, prefix, false, } } +} - rt := make([]tcpip.Route, 0) - exclude := tcpip.Route{ - Destination: prefix, - NIC: nicID, +func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + newAddr, + } } + return true +} + +func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + deprecatedAddr, + } + } +} - for _, r := range n.routeTable { - if r != exclude { - rt = append(rt, r) +func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + invalidatedAddr, + } + } +} + +// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. +func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { + if c := n.rdnssC; c != nil { + c <- ndpRDNSSEvent{ + nicID, + ndpRDNSS{ + addrs, + lifetime, + }, } } - n.routeTable = rt - return rt } // TestDADResolve tests that an address successfully resolves after performing @@ -246,7 +300,11 @@ func TestDADResolve(t *testing.T) { } for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } @@ -434,7 +492,7 @@ func TestDADFail(t *testing.T) { } opts.NDPConfigs.RetransmitTimer = time.Second * 2 - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -515,7 +573,7 @@ func TestDADStop(t *testing.T) { NDPConfigs: ndpConfigs, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -616,7 +674,7 @@ func TestSetNDPConfigurations(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPDisp: &ndpDisp, @@ -781,16 +839,33 @@ func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { // // Note, raBufWithPI does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink bool, vl uint32) tcpip.PacketBuffer { +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer { flags := uint8(0) if onLink { - flags |= 128 + // The OnLink flag is the 7th bit in the flags byte. + flags |= 1 << 7 + } + if auto { + // The Address Auto-Configuration flag is the 6th bit in the + // flags byte. + flags |= 1 << 6 } + // A valid header.NDPPrefixInformation must be 30 bytes. buf := [30]byte{} + // The first byte in a header.NDPPrefixInformation is the Prefix Length + // field. buf[0] = uint8(prefix.PrefixLen) + // The 2nd byte within a header.NDPPrefixInformation is the Flags field. buf[1] = flags + // The Valid Lifetime field starts after the 2nd byte within a + // header.NDPPrefixInformation. binary.BigEndian.PutUint32(buf[2:], vl) + // The Preferred Lifetime field starts after the 6th byte within a + // header.NDPPrefixInformation. + binary.BigEndian.PutUint32(buf[6:], pl) + // The Prefix Address field starts after the 14th byte within a + // header.NDPPrefixInformation. copy(buf[14:], prefix.Address) return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{ header.NDPPrefixInformation(buf[:]), @@ -812,10 +887,12 @@ func TestNoRouterDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), + routerC: make(chan ndpRouterEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -835,19 +912,27 @@ func TestNoRouterDiscovery(t *testing.T) { select { case <-ndpDisp.routerC: t.Fatal("unexpectedly discovered a router when configured not to") - case <-time.After(defaultTimeout): + default: } }) } } +// Check e to make sure that the event is for addr on nic with ID 1, and the +// discovered flag set to discovered. +func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { + return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) +} + // TestRouterDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered router when the dispatcher asks it not to. func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), + routerC: make(chan ndpRouterEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -861,59 +946,36 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { t.Fatalf("CreateNIC(1) = %s", err) } - routeTable := []tcpip.Route{ - { - header.IPv6EmptySubnet, - llAddr3, - 1, - }, - } - s.SetRouteTable(routeTable) - - // Rx an RA with short lifetime. - lifetime := time.Duration(1) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(lifetime))) + // Receive an RA for a router we should not remember. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) select { - case r := <-ndpDisp.routerC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != llAddr2 { - t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr2) + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr2, true); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for router discovery event") - } - - // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + default: + t.Fatal("expected router discovery event") } - // Wait for the normal invalidation time plus an extra second to - // make sure we do not actually receive any invalidation events as - // we should not have remembered the router in the first place. + // Wait for the invalidation time plus some buffer to make sure we do + // not actually receive any invalidation events as we should not have + // remembered the router in the first place. select { case <-ndpDisp.routerC: t.Fatal("should not have received any router events") - case <-time.After(lifetime*time.Second + defaultTimeout): - } - - // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } } func TestRouterDiscovery(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), + routerC: make(chan ndpRouterEvent, 1), rememberRouter: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -923,22 +985,29 @@ func TestRouterDiscovery(t *testing.T) { NDPDisp: &ndpDisp, }) - waitForEvent := func(addr tcpip.Address, discovered bool, timeout time.Duration) { + expectRouterEvent := func(addr tcpip.Address, discovered bool) { t.Helper() select { - case r := <-ndpDisp.routerC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, discovered); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.discovered != discovered { - t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered) + default: + t.Fatal("expected router discovery event") + } + } + + expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, false); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) } case <-time.After(timeout): - t.Fatal("timeout waiting for router discovery event") + t.Fatal("timed out waiting for router discovery event") } } @@ -952,27 +1021,17 @@ func TestRouterDiscovery(t *testing.T) { select { case <-ndpDisp.routerC: t.Fatal("unexpectedly discovered a router with 0 lifetime") - case <-time.After(defaultTimeout): + default: } // Rx an RA from lladdr2 with a huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - waitForEvent(llAddr2, true, defaultTimeout) - - // Should have a default route through the discovered router. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + expectRouterEvent(llAddr2, true) // Rx an RA from another router (lladdr3) with non-zero lifetime. l3Lifetime := time.Duration(6) e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime))) - waitForEvent(llAddr3, true, defaultTimeout) - - // Should have default routes through the discovered routers. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + expectRouterEvent(llAddr3, true) // Rx an RA from lladdr2 with lesser lifetime. l2Lifetime := time.Duration(2) @@ -980,12 +1039,7 @@ func TestRouterDiscovery(t *testing.T) { select { case <-ndpDisp.routerC: t.Fatal("Should not receive a router event when updating lifetimes for known routers") - case <-time.After(defaultTimeout): - } - - // Should still have a default route through the discovered routers. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + default: } // Wait for lladdr2's router invalidation timer to fire. The lifetime @@ -995,31 +1049,15 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - waitForEvent(llAddr2, false, l2Lifetime*time.Second+defaultTimeout) - - // Should no longer have the default route through lladdr2. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + expectAsyncRouterInvalidationEvent(llAddr2, l2Lifetime*time.Second+defaultTimeout) // Rx an RA from lladdr2 with huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - waitForEvent(llAddr2, true, defaultTimeout) - - // Should have a default route through the discovered routers. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}, {header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + expectRouterEvent(llAddr2, true) // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - waitForEvent(llAddr2, false, defaultTimeout) - - // Should have deleted the default route through the router that just - // got invalidated. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + expectRouterEvent(llAddr2, false) // Wait for lladdr3's router invalidation timer to fire. The lifetime // of the router should have been updated to the most recent (smaller) @@ -1028,23 +1066,19 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - waitForEvent(llAddr3, false, l3Lifetime*time.Second+defaultTimeout) - - // Should not have any routes now that all discovered routers have been - // invalidated. - if got := len(s.GetRouteTable()); got != 0 { - t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got) - } + expectAsyncRouterInvalidationEvent(llAddr3, l3Lifetime*time.Second+defaultTimeout) } // TestRouterDiscoveryMaxRouters tests that only // stack.MaxDiscoveredDefaultRouters discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), + routerC: make(chan ndpRouterEvent, 1), rememberRouter: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1058,8 +1092,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { t.Fatalf("CreateNIC(1) = %s", err) } - expectedRt := [stack.MaxDiscoveredDefaultRouters]tcpip.Route{} - // Receive an RA from 2 more than the max number of discovered routers. for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { linkAddr := []byte{2, 2, 3, 4, 5, 0} @@ -1069,36 +1101,23 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) if i <= stack.MaxDiscoveredDefaultRouters { - expectedRt[i-1] = tcpip.Route{header.IPv6EmptySubnet, llAddr, 1} select { - case r := <-ndpDisp.routerC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr, true); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) } - if r.addr != llAddr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr) - } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for router discovery event") + default: + t.Fatal("expected router discovery event") } } else { select { case <-ndpDisp.routerC: t.Fatal("should not have discovered a new router after we already discovered the max number of routers") - case <-time.After(defaultTimeout): + default: } } } - - // Should only have default routes for the first - // stack.MaxDiscoveredDefaultRouters discovered routers. - if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) - } } // TestNoPrefixDiscovery tests that prefix discovery will not be performed if @@ -1121,10 +1140,12 @@ func TestNoPrefixDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), + prefixC: make(chan ndpPrefixEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1140,30 +1161,34 @@ func TestNoPrefixDiscovery(t *testing.T) { } // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, 10)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly discovered a prefix when configured not to") - case <-time.After(defaultTimeout): + default: } }) } } +// Check e to make sure that the event is for prefix on nic with ID 1, and the +// discovered flag set to discovered. +func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { + return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) +} + // TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered on-link prefix when the dispatcher asks it not to. func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - subnet := prefix.Subnet() + t.Parallel() + + prefix, subnet, _ := prefixSubnetAddr(0, "") ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), + prefixC: make(chan ndpPrefixEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1178,75 +1203,40 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { t.Fatalf("CreateNIC(1) = %s", err) } - routeTable := []tcpip.Route{ - { - header.IPv6EmptySubnet, - llAddr3, - 1, - }, - } - s.SetRouteTable(routeTable) - - // Rx an RA with prefix with a short lifetime. - const lifetime = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, lifetime)) + // Receive an RA with prefix that we should not remember. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != subnet { - t.Fatalf("got r.prefix = %s, want = %s", r.prefix, subnet) + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for prefix discovery event") - } - - // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + default: + t.Fatal("expected prefix discovery event") } - // Wait for the normal invalidation time plus some buffer to - // make sure we do not actually receive any invalidation events as - // we should not have remembered the prefix in the first place. + // Wait for the invalidation time plus some buffer to make sure we do + // not actually receive any invalidation events as we should not have + // remembered the prefix in the first place. select { case <-ndpDisp.prefixC: t.Fatal("should not have received any prefix events") - case <-time.After(lifetime*time.Second + defaultTimeout): - } - - // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } } func TestPrefixDiscovery(t *testing.T) { - prefix1 := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - prefix2 := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x09\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - prefix3 := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x09\x0a\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 72, - } - subnet1 := prefix1.Subnet() - subnet2 := prefix2.Subnet() - subnet3 := prefix3.Subnet() + t.Parallel() + + prefix1, subnet1, _ := prefixSubnetAddr(0, "") + prefix2, subnet2, _ := prefixSubnetAddr(1, "") + prefix3, subnet3, _ := prefixSubnetAddr(2, "") ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), + prefixC: make(chan ndpPrefixEvent, 1), rememberPrefix: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1256,106 +1246,72 @@ func TestPrefixDiscovery(t *testing.T) { NDPDisp: &ndpDisp, }) - waitForEvent := func(subnet tcpip.Subnet, discovered bool, timeout time.Duration) { + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { t.Helper() select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != subnet { - t.Fatalf("got r.prefix = %s, want = %s", r.prefix, subnet) + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - if r.discovered != discovered { - t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered) - } - case <-time.After(timeout): - t.Fatal("timeout waiting for prefix discovery event") + default: + t.Fatal("expected prefix discovery event") } } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - case <-time.After(defaultTimeout): + default: } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 100)) - waitForEvent(subnet1, true, defaultTimeout) - - // Should have added a device route for subnet1 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) + expectPrefixEvent(subnet1, true) // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, 100)) - waitForEvent(subnet2, true, defaultTimeout) - - // Should have added a device route for subnet2 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) + expectPrefixEvent(subnet2, true) // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, 100)) - waitForEvent(subnet3, true, defaultTimeout) - - // Should have added a device route for subnet3 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) + expectPrefixEvent(subnet3, true) // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 0)) - waitForEvent(subnet1, false, defaultTimeout) - - // Should have removed the device route for subnet1 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) - } + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + expectPrefixEvent(subnet1, false) // Receive an RA with prefix2 in a PI with lesser lifetime. lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, lifetime)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly received prefix event when updating lifetime") - case <-time.After(defaultTimeout): - } - - // Should not have updated route table. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + default: } // Wait for prefix2's most recent invalidation timer plus some buffer to // expire. - waitForEvent(subnet2, false, time.Duration(lifetime)*time.Second+defaultTimeout) - - // Should have removed the device route for subnet2 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet2, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(lifetime)*time.Second + defaultTimeout): + t.Fatal("timed out waiting for prefix discovery event") } // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, 0)) - waitForEvent(subnet3, false, defaultTimeout) - - // Should not have any routes. - if got := len(s.GetRouteTable()); got != 0 { - t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got) - } + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) + expectPrefixEvent(subnet3, false) } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { @@ -1364,10 +1320,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // invalidate the prefix. const testInfiniteLifetimeSeconds = 2 const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second - saved := header.NDPPrefixInformationInfiniteLifetime - header.NDPPrefixInformationInfiniteLifetime = testInfiniteLifetime + saved := header.NDPInfiniteLifetime + header.NDPInfiniteLifetime = testInfiniteLifetime defer func() { - header.NDPPrefixInformationInfiniteLifetime = saved + header.NDPInfiniteLifetime = saved }() prefix := tcpip.AddressWithPrefix{ @@ -1377,10 +1333,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { subnet := prefix.Subnet() ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), + prefixC: make(chan ndpPrefixEvent, 1), rememberPrefix: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1390,33 +1346,27 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { NDPDisp: &ndpDisp, }) - waitForEvent := func(discovered bool, timeout time.Duration) { + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { t.Helper() select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Errorf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != subnet { - t.Errorf("got r.prefix = %s, want = %s", r.prefix, subnet) + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - if r.discovered != discovered { - t.Errorf("got r.discovered = %t, want = %t", r.discovered, discovered) - } - case <-time.After(timeout): - t.Fatal("timeout waiting for prefix discovery event") + default: + t.Fatal("expected prefix discovery event") } } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Receive an RA with prefix in an NDP Prefix Information option (PI) // with infinite valid lifetime which should not get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds)) - waitForEvent(true, defaultTimeout) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + expectPrefixEvent(subnet, true) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") @@ -1425,16 +1375,23 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with finite lifetime. // The prefix should get invalidated after 1s. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds-1)) - waitForEvent(false, testInfiniteLifetime) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(testInfiniteLifetime): + t.Fatal("timed out waiting for prefix discovery event") + } // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds-1)) - waitForEvent(true, defaultTimeout) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + expectPrefixEvent(subnet, true) // Receive an RA with prefix with an infinite lifetime. // The prefix should not be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") @@ -1443,7 +1400,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with a prefix with a lifetime value greater than the // set infinite lifetime value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds+1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") @@ -1452,18 +1409,20 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with 0 lifetime. // The prefix should get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, 0)) - waitForEvent(false, defaultTimeout) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0)) + expectPrefixEvent(subnet, false) } // TestPrefixDiscoveryMaxRouters tests that only // stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), rememberPrefix: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1479,7 +1438,6 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { } optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2) - expectedRt := [stack.MaxDiscoveredOnLinkPrefixes]tcpip.Route{} prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} // Receive an RA with 2 more than the max number of discovered on-link @@ -1499,41 +1457,1500 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { copy(buf[14:], prefix.Address) optSer[i] = header.NDPPrefixInformation(buf[:]) - - if i < stack.MaxDiscoveredOnLinkPrefixes { - expectedRt[i] = tcpip.Route{prefixes[i], tcpip.Address([]byte(nil)), 1} - } } e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { if i < stack.MaxDiscoveredOnLinkPrefixes { select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != prefixes[i] { - t.Fatalf("got r.prefix = %s, want = %s", r.prefix, prefixes[i]) + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for prefix discovery event") + default: + t.Fatal("expected prefix discovery event") } } else { select { case <-ndpDisp.prefixC: t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes") + default: + } + } + } +} + +// Checks to see if list contains an IPv6 address, item. +func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { + protocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: item, + } + + for _, i := range list { + if i == protocolAddress { + return true + } + } + + return false +} + +// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. +func TestNoAutoGenAddr(t *testing.T) { + prefix, _, _ := prefixSubnetAddr(0, "") + + // Being configured to auto-generate addresses means handle and + // autogen are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, autogen = + // true and forwarding = false (the required configuration to do + // SLAAC) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + autogen := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + AutoGenGlobalAddresses: autogen, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with prefix with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) + + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when configured not to") + default: + } + }) + } +} + +// Check e to make sure that the event is for addr on nic with ID 1, and the +// event type is set to eventType. +func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { + return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) +} + +// TestAutoGenAddr tests that an address is properly generated and invalidated +// when configured to do so. +func TestAutoGenAddr(t *testing.T) { + const newMinVL = 2 + newMinVLDuration := newMinVL * time.Second + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with 0 lifetime") + default: + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") + default: + } + + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } + + // Refresh valid lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") + default: + } + + // Wait for addr of prefix1 to be invalidated. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(newMinVLDuration + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } +} + +// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher, +// channel.Endpoint and stack.Stack. +// +// stack.Stack will have a default route through the router (llAddr3) installed +// and a static link-address (linkAddr3) added to the link address cache for the +// router. +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { + t.Helper() + ndpDisp := &ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: ndpDisp, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: llAddr3, + NIC: nicID, + }}) + s.AddLinkAddress(nicID, llAddr3, linkAddr3) + return ndpDisp, e, s +} + +// addrForNewConnection returns the local address used when creating a new +// connection. +func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + v := tcpip.V6OnlyOption(1) + if err := ep.SetSockOpt(v); err != nil { + t.Fatalf("SetSockOpt(%+v): %s", v, err) + } + if err := ep.Connect(dstAddr); err != nil { + t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) + } + got, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("ep.GetLocalAddress(): %s", err) + } + return got.Addr +} + +// addrForNewConnectionWithAddr returns the local address used when creating a +// new connection with a specific local address. +func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + v := tcpip.V6OnlyOption(1) + if err := ep.SetSockOpt(v); err != nil { + t.Fatalf("SetSockOpt(%+v): %s", v, err) + } + if err := ep.Bind(addr); err != nil { + t.Fatalf("ep.Bind(%+v): %s", addr, err) + } + if err := ep.Connect(dstAddr); err != nil { + t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) + } + got, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("ep.GetLocalAddress(): %s", err) + } + return got.Addr +} + +// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when +// receiving a PI with 0 preferred lifetime. +func TestAutoGenAddrDeprecateFromPI(t *testing.T) { + const nicID = 1 + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() + + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } + + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } + + // Receive PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + expectPrimaryAddr(addr1) + + // Deprecate addr for prefix1 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + // addr should still be the primary endpoint as there are no other addresses. + expectPrimaryAddr(addr1) + + // Refresh lifetimes of addr generated from prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Deprecate addr for prefix2 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr1 should be the primary endpoint now since addr2 is deprecated but + // addr1 is not. + expectPrimaryAddr(addr1) + // addr2 is deprecated but if explicitly requested, it should be used. + fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + } + + // Another PI w/ 0 preferred lifetime should not result in a deprecation + // event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + } + + // Refresh lifetimes of addr generated from prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) +} + +// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated +// when its preferred lifetime expires. +func TestAutoGenAddrTimerDeprecation(t *testing.T) { + const nicID = 1 + const newMinVL = 2 + newMinVLDuration := newMinVL * time.Second + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() + + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } + + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } + + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Receive a PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr1) + + // Refresh lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since addr1 is deprecated but + // addr2 is not. + expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. + fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + } + + // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make + // sure we do not get a deprecation event again. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + } + + // Refresh lifetimes for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + // addr1 is the primary endpoint again since it is non-deprecated now. + expectPrimaryAddr(addr1) + + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since it is not deprecated. + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + } + + // Wait for addr of prefix1 to be invalidated. + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultTimeout) + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Refresh both lifetimes for addr of prefix2 to the same value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + + // Wait for a deprecation then invalidation events, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation handlers could be handled in + // either deprecation then invalidation, or invalidation then deprecation + // (which should be cancelled by the invalidation handler). + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we should not get a deprecation + // event after. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") case <-time.After(defaultTimeout): } + } else { + t.Fatalf("got unexpected auto-generated event") } + + case <-time.After(newMinVLDuration + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) } + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should not have %s in the list of addresses", addr2) + } + // Should not have any primary endpoints. + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); got != want { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) + } + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + v := tcpip.V6OnlyOption(1) + if err := ep.SetSockOpt(v); err != nil { + t.Fatalf("SetSockOpt(%+v): %s", v, err) + } + + if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { + t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute) + } +} - // Should only have device routes for the first - // stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes. - if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) +// Tests transitioning a SLAAC address's valid lifetime between finite and +// infinite values. +func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { + const infiniteVLSeconds = 2 + const minVLSeconds = 1 + savedIL := header.NDPInfiniteLifetime + savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL + header.NDPInfiniteLifetime = savedIL + }() + stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second + header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + tests := []struct { + name string + infiniteVL uint32 + }{ + { + name: "EqualToInfiniteVL", + infiniteVL: infiniteVLSeconds, + }, + // Our implementation supports changing header.NDPInfiniteLifetime for tests + // such that a packet can be received where the lifetime field has a value + // greater than header.NDPInfiniteLifetime. Because of this, we test to make + // sure that receiving a value greater than header.NDPInfiniteLifetime is + // handled the same as when receiving a value equal to + // header.NDPInfiniteLifetime. + { + name: "MoreThanInfiniteVL", + infiniteVL: infiniteVLSeconds + 1, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with finite prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + default: + t.Fatal("expected addr auto gen event") + } + + // Receive an new RA with prefix with infinite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) + + // Receive a new RA with prefix with finite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + case <-time.After(minVLSeconds*time.Second + defaultTimeout): + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } + }) +} + +// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an +// auto-generated address only gets updated when required to, as specified in +// RFC 4862 section 5.5.3.e. +func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { + const infiniteVL = 4294967295 + const newMinVL = 4 + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + tests := []struct { + name string + ovl uint32 + nvl uint32 + evl uint32 + }{ + // Should update the VL to the minimum VL for updating if the + // new VL is less than newMinVL but was originally greater than + // it. + { + "LargeVLToVLLessThanMinVLForUpdate", + 9999, + 1, + newMinVL, + }, + { + "LargeVLTo0", + 9999, + 0, + newMinVL, + }, + { + "InfiniteVLToVLLessThanMinVLForUpdate", + infiniteVL, + 1, + newMinVL, + }, + { + "InfiniteVLTo0", + infiniteVL, + 0, + newMinVL, + }, + + // Should not update VL if original VL was less than newMinVL + // and the new VL is also less than newMinVL. + { + "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", + newMinVL - 1, + newMinVL - 3, + newMinVL - 1, + }, + + // Should take the new VL if the new VL is greater than the + // remaining time or is greater than newMinVL. + { + "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", + newMinVL + 5, + newMinVL + 3, + newMinVL + 3, + }, + { + "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", + newMinVL - 3, + newMinVL - 1, + newMinVL - 1, + }, + { + "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", + newMinVL - 3, + newMinVL + 1, + newMinVL + 1, + }, + } + + const delta = 500 * time.Millisecond + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix with initial VL, + // test.ovl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + + // Receive an new RA with prefix with new VL, + // test.nvl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) + + // + // Validate that the VL for the address got set + // to test.evl. + // + + // Make sure we do not get any invalidation + // events until atleast 500ms (delta) before + // test.evl. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + case <-time.After(time.Duration(test.evl)*time.Second - delta): + } + + // Wait for another second (2x delta), but now + // we expect the invalidation event. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + case <-time.After(2 * delta): + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } + }) +} + +// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed +// by the user, its resources will be cleaned up and an invalidation event will +// be sent to the integrator. +func TestAutoGenAddrRemoval(t *testing.T) { + t.Parallel() + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive a PI to auto-generate an address. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr, newAddr) + + // Removing the address should result in an invalidation event + // immediately. + if err := s.RemoveAddress(1, addr.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + + // Wait for the original valid lifetime to make sure the original timer + // got stopped/cleaned up. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + } +} + +// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that +// is already assigned to the NIC, the static address remains. +func TestAutoGenAddrStaticConflict(t *testing.T) { + t.Parallel() + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Add the address as a static address before SLAAC tries to add it. + if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Receive a PI where the generated address will be the same as the one + // that we already have assigned statically. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") + default: + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Should not get an invalidation event after the PI's invalidation + // time. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } +} + +// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use +// opaque interface identifiers when configured to do so. +func TestAutoGenAddrWithOpaqueIID(t *testing.T) { + t.Parallel() + + const nicID = 1 + const nicName = "nic1" + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte + secretKey := secretKeyBuf[:] + n, err := rand.Read(secretKey) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) + } + + prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) + prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) + // addr1 and addr2 are the addresses that are expected to be generated when + // stack.Stack is configured to generate opaque interface identifiers as + // defined by RFC 7217. + addrBytes := []byte(subnet1.ID()) + addr1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)), + PrefixLen: 64, + } + addrBytes = []byte(subnet2.ID()) + addr2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)), + PrefixLen: 64, + } + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + + if err := s.CreateNamedNIC(nicID, nicName, e); err != nil { + t.Fatalf("CreateNamedNIC(%d, %q, _) = %s", nicID, nicName, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive an RA with prefix1 in a PI. + const validLifetimeSecondPrefix1 = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + + // Receive an RA with prefix2 in a PI with a large valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + + // Wait for addr of prefix1 to be invalidated. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } +} + +// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event +// to the integrator when an RA is received with the NDP Recursive DNS Server +// option with at least one valid address. +func TestNDPRecursiveDNSServerDispatch(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opt header.NDPRecursiveDNSServer + expected *ndpRDNSS + }{ + { + "Unspecified", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }), + nil, + }, + { + "Multicast", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }), + nil, + }, + { + "OptionTooSmall", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, + }), + nil, + }, + { + "0Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + }), + nil, + }, + { + "Valid1Address", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + }, + 2 * time.Second, + }, + }, + { + "Valid2Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + }, + time.Second, + }, + }, + { + "Valid3Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03", + }, + 0, + }, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + // We do not expect more than a single RDNSS + // event at any time for this test. + rdnssC: make(chan ndpRDNSSEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt})) + + if test.expected != nil { + select { + case e := <-ndpDisp.rdnssC: + if e.nicID != 1 { + t.Errorf("got rdnss nicID = %d, want = 1", e.nicID) + } + if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" { + t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff) + } + if e.rdnss.lifetime != test.expected.lifetime { + t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime) + } + default: + t.Fatal("expected an RDNSS option event") + } + } + + // Should have no more RDNSS options. + select { + case e := <-ndpDisp.rdnssC: + t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e) + default: + } + }) + } +} + +// TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers +// and prefixes, and auto-generated addresses get invalidated when a NIC +// becomes a router. +func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { + t.Parallel() + + const ( + lifetimeSeconds = 5 + maxEvents = 4 + nicID1 = 1 + nicID2 = 2 + ) + + prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1) + e2Addr1 := addrForSubnet(subnet1, linkAddr2) + e2Addr2 := addrForSubnet(subnet2, linkAddr2) + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, maxEvents), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, maxEvents), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + e1 := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) + } + + e2 := channel.New(0, 1280, linkAddr2) + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) + } + + expectRouterEvent := func() (bool, ndpRouterEvent) { + select { + case e := <-ndpDisp.routerC: + return true, e + default: + } + + return false, ndpRouterEvent{} + } + + expectPrefixEvent := func() (bool, ndpPrefixEvent) { + select { + case e := <-ndpDisp.prefixC: + return true, e + default: + } + + return false, ndpPrefixEvent{} + } + + expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { + select { + case e := <-ndpDisp.autoGenAddrC: + return true, e + default: + } + + return false, ndpAutoGenAddrEvent{} + } + + // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr1 and + // llAddr2) w/ PI (for prefix1 in RA from llAddr1 and prefix2 in RA from + // llAddr2) to discover multiple routers and prefixes, and auto-gen + // multiple addresses. + + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + // We have other tests that make sure we receive the *correct* events + // on normal discovery of routers/prefixes, and auto-generated + // addresses. Here we just make sure we get an event and let other tests + // handle the correctness check. + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) + } + + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) + } + + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) + } + + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) + } + + // We should have the auto-generated addresses added. + nicinfo := s.NICInfo() + nic1Addrs := nicinfo[nicID1].ProtocolAddresses + nic2Addrs := nicinfo[nicID2].ProtocolAddresses + if !contains(nic1Addrs, e1Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if !contains(nic1Addrs, e1Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if !contains(nic2Addrs, e2Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if !contains(nic2Addrs, e2Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } + + // We can't proceed any further if we already failed the test (missing + // some discovery/auto-generated address events or addresses). + if t.Failed() { + t.FailNow() + } + + s.SetForwarding(true) + + // Collect invalidation events after becoming a router + gotRouterEvents := make(map[ndpRouterEvent]int) + for i := 0; i < maxEvents; i++ { + ok, e := expectRouterEvent() + if !ok { + t.Errorf("expected %d router events after becoming a router; got = %d", maxEvents, i) + break + } + gotRouterEvents[e]++ + } + gotPrefixEvents := make(map[ndpPrefixEvent]int) + for i := 0; i < maxEvents; i++ { + ok, e := expectPrefixEvent() + if !ok { + t.Errorf("expected %d prefix events after becoming a router; got = %d", maxEvents, i) + break + } + gotPrefixEvents[e]++ + } + gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) + for i := 0; i < maxEvents; i++ { + ok, e := expectAutoGenAddrEvent() + if !ok { + t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", maxEvents, i) + break + } + gotAutoGenAddrEvents[e]++ + } + + // No need to proceed any further if we already failed the test (missing + // some invalidation events). + if t.Failed() { + t.FailNow() + } + + expectedRouterEvents := map[ndpRouterEvent]int{ + {nicID: nicID1, addr: llAddr1, discovered: false}: 1, + {nicID: nicID1, addr: llAddr2, discovered: false}: 1, + {nicID: nicID2, addr: llAddr1, discovered: false}: 1, + {nicID: nicID2, addr: llAddr2, discovered: false}: 1, + } + if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { + t.Errorf("router events mismatch (-want +got):\n%s", diff) + } + expectedPrefixEvents := map[ndpPrefixEvent]int{ + {nicID: nicID1, prefix: subnet1, discovered: false}: 1, + {nicID: nicID1, prefix: subnet2, discovered: false}: 1, + {nicID: nicID2, prefix: subnet1, discovered: false}: 1, + {nicID: nicID2, prefix: subnet2, discovered: false}: 1, + } + if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { + t.Errorf("prefix events mismatch (-want +got):\n%s", diff) + } + expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ + {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, + } + if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { + t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) + } + + // Make sure the auto-generated addresses got removed. + nicinfo = s.NICInfo() + nic1Addrs = nicinfo[nicID1].ProtocolAddresses + nic2Addrs = nicinfo[nicID2].ProtocolAddresses + if contains(nic1Addrs, e1Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if contains(nic1Addrs, e1Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if contains(nic2Addrs, e2Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if contains(nic2Addrs, e2Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } + + // Should not get any more events (invalidation timers should have been + // cancelled when we transitioned into a router). + time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) + select { + case <-ndpDisp.routerC: + t.Error("unexpected router event") + default: + } + select { + case <-ndpDisp.prefixC: + t.Error("unexpected prefix event") + default: + } + select { + case <-ndpDisp.autoGenAddrC: + t.Error("unexpected auto-generated address event") + default: } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3f8d7312c..4144d5d0f 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -27,11 +27,10 @@ import ( // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { - stack *Stack - id tcpip.NICID - name string - linkEP LinkEndpoint - loopback bool + stack *Stack + id tcpip.NICID + name string + linkEP LinkEndpoint mu sync.RWMutex spoofing bool @@ -85,7 +84,7 @@ const ( ) // newNIC returns a new NIC using the default NDP configurations from stack. -func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { +func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For // example, make sure that the link address it provides is a valid // unicast ethernet address. @@ -99,7 +98,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback id: id, name: name, linkEP: ep, - loopback: loopback, primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), @@ -115,10 +113,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback }, }, ndp: ndpState{ - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), }, } nic.ndp.nic = nic @@ -173,23 +172,28 @@ func (n *NIC) enable() *tcpip.Error { return err } - if !n.stack.autoGenIPv6LinkLocal { + // Do not auto-generate an IPv6 link-local address for loopback devices. + if !n.stack.autoGenIPv6LinkLocal || n.isLoopback() { return nil } - l2addr := n.linkEP.LinkAddress() + var addr tcpip.Address + if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey) + } else { + l2addr := n.linkEP.LinkAddress() - // Only attempt to generate the link-local address if we have a - // valid MAC address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address - // (provided by LinkEndpoint.LinkAddress) before reaching this - // point. - if !header.IsValidUnicastEthernetAddress(l2addr) { - return nil - } + // Only attempt to generate the link-local address if we have a valid MAC + // address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + if !header.IsValidUnicastEthernetAddress(l2addr) { + return nil + } - addr := header.LinkLocalAddr(l2addr) + addr = header.LinkLocalAddr(l2addr) + } _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ Protocol: header.IPv6ProtocolNumber, @@ -202,6 +206,18 @@ func (n *NIC) enable() *tcpip.Error { return err } +// becomeIPv6Router transitions n into an IPv6 router. +// +// When transitioning into an IPv6 router, host-only state (NDP discovered +// routers, discovered on-link prefixes, and auto-generated addresses) will +// be cleaned up/invalidated. +func (n *NIC) becomeIPv6Router() { + n.mu.Lock() + defer n.mu.Unlock() + + n.ndp.cleanupHostOnlyState() +} + // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { @@ -222,6 +238,10 @@ func (n *NIC) isPromiscuousMode() bool { return rv } +func (n *NIC) isLoopback() bool { + return n.linkEP.Capabilities()&CapabilityLoopback != 0 +} + // setSpoofing enables or disables address spoofing. func (n *NIC) setSpoofing(enable bool) { n.mu.Lock() @@ -231,17 +251,61 @@ func (n *NIC) setSpoofing(enable bool) { // primaryEndpoint returns the primary endpoint of n for the given network // protocol. +// +// primaryEndpoint will return the first non-deprecated endpoint if such an +// endpoint exists. If no non-deprecated endpoint exists, the first deprecated +// endpoint will be returned. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { n.mu.RLock() defer n.mu.RUnlock() + var deprecatedEndpoint *referencedNetworkEndpoint for _, r := range n.primary[protocol] { - if r.isValidForOutgoing() && r.tryIncRef() { - return r + if !r.isValidForOutgoing() { + continue + } + + if !r.deprecated { + if r.tryIncRef() { + // r is not deprecated, so return it immediately. + // + // If we kept track of a deprecated endpoint, decrement its reference + // count since it was incremented when we decided to keep track of it. + if deprecatedEndpoint != nil { + deprecatedEndpoint.decRefLocked() + deprecatedEndpoint = nil + } + + return r + } + } else if deprecatedEndpoint == nil && r.tryIncRef() { + // We prefer an endpoint that is not deprecated, but we keep track of r in + // case n doesn't have any non-deprecated endpoints. + // + // If we end up finding a more preferred endpoint, r's reference count + // will be decremented when such an endpoint is found. + deprecatedEndpoint = r } } - return nil + // n doesn't have any valid non-deprecated endpoints, so return + // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated + // endpoints either). + return deprecatedEndpoint +} + +// hasPermanentAddrLocked returns true if n has a permanent (including currently +// tentative) address, addr. +func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { + ref, ok := n.endpoints[NetworkEndpointID{addr}] + + if !ok { + return false + } + + kind := ref.getKind() + + return kind == permanent || kind == permanentTentative } func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { @@ -335,7 +399,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, temporary) + }, peb, temporary, static, false) n.mu.Unlock() return ref @@ -384,10 +448,10 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } } - return n.addAddressLocked(protocolAddress, peb, permanent) + return n.addAddressLocked(protocolAddress, peb, permanent, static, false) } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { // TODO(b/141022673): Validate IP address before adding them. // Sanity check. @@ -417,11 +481,13 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - kind: kind, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, + configType: configType, + deprecated: deprecated, } // Set up cache if link address resolution exists for this protocol. @@ -520,6 +586,51 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { return addrs } +// primaryAddress returns the primary address associated with this NIC. +// +// primaryAddress will return the first non-deprecated address if such an +// address exists. If no non-deprecated address exists, the first deprecated +// address will be returned. +func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { + n.mu.RLock() + defer n.mu.RUnlock() + + list, ok := n.primary[proto] + if !ok { + return tcpip.AddressWithPrefix{} + } + + var deprecatedEndpoint *referencedNetworkEndpoint + for _, ref := range list { + // Don't include tentative, expired or tempory endpoints to avoid confusion + // and prevent the caller from using those. + switch ref.getKind() { + case permanentTentative, permanentExpired, temporary: + continue + } + + if !ref.deprecated { + return tcpip.AddressWithPrefix{ + Address: ref.ep.ID().LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + } + } + + if deprecatedEndpoint == nil { + deprecatedEndpoint = ref + } + } + + if deprecatedEndpoint != nil { + return tcpip.AddressWithPrefix{ + Address: deprecatedEndpoint.ep.ID().LocalAddress, + PrefixLen: deprecatedEndpoint.ep.PrefixLen(), + } + } + + return tcpip.AddressWithPrefix{} +} + // AddAddressRange adds a range of addresses to n, so that it starts accepting // packets targeted at the given addresses and network protocol. The range is // given by a subnet address, and all addresses contained in the subnet are @@ -624,9 +735,18 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) - // If we are removing a tentative IPv6 unicast address, stop DAD. - if isIPv6Unicast && kind == permanentTentative { - n.ndp.stopDuplicateAddressDetection(addr) + if isIPv6Unicast { + // If we are removing a tentative IPv6 unicast address, stop + // DAD. + if kind == permanentTentative { + n.ndp.stopDuplicateAddressDetection(addr) + } + + // If we are removing an address generated via SLAAC, cleanup + // its SLAAC resources and notify the integrator. + if r.configType == slaac { + n.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) + } } r.setKind(permanentExpired) @@ -989,7 +1109,7 @@ const ( // removing the permanent address from the NIC. permanent - // An expired permanent endoint is a permanent endoint that had its address + // An expired permanent endpoint is a permanent endpoint that had its address // removed from the NIC, and it is waiting to be removed once no more routes // hold a reference to it. This is achieved by decreasing its reference count // by 1. If its address is re-added before the endpoint is removed, its type @@ -1035,6 +1155,19 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep } } +type networkEndpointConfigType int32 + +const ( + // A statically configured endpoint is an address that was added by + // some user-specified action (adding an explicit address, joining a + // multicast group). + static networkEndpointConfigType = iota + + // A slaac configured endpoint is an IPv6 endpoint that was + // added by SLAAC as per RFC 4862 section 5.5.3. + slaac +) + type referencedNetworkEndpoint struct { ep NetworkEndpoint nic *NIC @@ -1050,6 +1183,15 @@ type referencedNetworkEndpoint struct { // networkEndpointKind must only be accessed using {get,set}Kind(). kind networkEndpointKind + + // configType is the method that was used to configure this endpoint. + // This must never change after the endpoint is added to a NIC. + configType networkEndpointConfigType + + // deprecated indicates whether or not the endpoint should be considered + // deprecated. That is, when deprecated is true, other endpoints that are not + // deprecated should be preferred. + deprecated bool } func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 61fd46d66..2b8751d49 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -234,15 +234,15 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. - WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 34307ae07..517f4b941 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -158,7 +158,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack return tcpip.ErrInvalidEndpointState } - err := r.ref.ep.WritePacket(r, gso, params, r.Loop, pkt) + err := r.ref.ep.WritePacket(r, gso, params, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { @@ -174,7 +174,7 @@ func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params Network return 0, tcpip.ErrInvalidEndpointState } - n, err := r.ref.ep.WritePackets(r, gso, pkts, params, r.Loop) + n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) } @@ -195,7 +195,7 @@ func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - if err := r.ref.ep.WriteHeaderIncludedPacket(r, r.Loop, pkt); err != nil { + if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0e88643a4..807f910f6 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -352,6 +352,38 @@ func (u *uniqueIDGenerator) UniqueID() uint64 { return atomic.AddUint64((*uint64)(u), 1) } +// NICNameFromID is a function that returns a stable name for the specified NIC, +// even if different NIC IDs are used to refer to the same NIC in different +// program runs. It is used when generating opaque interface identifiers (IIDs). +// If the NIC was created with a name, it will be passed to NICNameFromID. +// +// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are +// generated for the same prefix on differnt NICs. +type NICNameFromID func(tcpip.NICID, string) string + +// OpaqueInterfaceIdentifierOptions holds the options related to the generation +// of opaque interface indentifiers (IIDs) as defined by RFC 7217. +type OpaqueInterfaceIdentifierOptions struct { + // NICNameFromID is a function that returns a stable name for a specified NIC, + // even if the NIC ID changes over time. + // + // Must be specified to generate the opaque IID. + NICNameFromID NICNameFromID + + // SecretKey is a pseudo-random number used as the secret key when generating + // opaque IIDs as defined by RFC 7217. The key SHOULD be at least + // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness + // requirements for security as outlined by RFC 4086. SecretKey MUST NOT + // change between program runs, unless explicitly changed. + // + // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey + // MUST NOT be modified after Stack is created. + // + // May be nil, but a nil value is highly discouraged to maintain + // some level of randomness between nodes. + SecretKey []byte +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -412,8 +444,8 @@ type Stack struct { ndpConfigs NDPConfigurations // autoGenIPv6LinkLocal determines whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled NICs. - // See the AutoGenIPv6LinkLocal field of Options for more details. + // to auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. autoGenIPv6LinkLocal bool // ndpDisp is the NDP event dispatcher that is used to send the netstack @@ -422,6 +454,10 @@ type Stack struct { // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID + + // opaqueIIDOpts hold the options for generating opaque interface identifiers + // (IIDs) as outlined by RFC 7217. + opaqueIIDOpts OpaqueInterfaceIdentifierOptions } // UniqueID is an abstract generator of unique identifiers. @@ -460,13 +496,15 @@ type Options struct { // before assigning an address to a NIC. NDPConfigs NDPConfigurations - // AutoGenIPv6LinkLocal determins whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled NICs. + // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to + // auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. + // // Note, setting this to true does not mean that a link-local address - // will be assigned right away, or at all. If Duplicate Address - // Detection is enabled, an address will only be assigned if it - // successfully resolves. If it fails, no further attempt will be made - // to auto-generate an IPv6 link-local address. + // will be assigned right away, or at all. If Duplicate Address Detection + // is enabled, an address will only be assigned if it successfully resolves. + // If it fails, no further attempt will be made to auto-generate an IPv6 + // link-local address. // // The generated link-local address will follow RFC 4291 Appendix A // guidelines. @@ -479,6 +517,10 @@ type Options struct { // RawFactory produces raw endpoints. Raw endpoints are enabled only if // this is non-nil. RawFactory RawFactory + + // OpaqueIIDOpts hold the options for generating opaque interface identifiers + // (IIDs) as outlined by RFC 7217. + OpaqueIIDOpts OpaqueInterfaceIdentifierOptions } // TransportEndpointInfo holds useful information about a transport endpoint @@ -549,6 +591,7 @@ func New(opts Options) *Stack { autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, + opaqueIIDOpts: opts.OpaqueIIDOpts, } // Add specified network protocols. @@ -662,11 +705,31 @@ func (s *Stack) Stats() tcpip.Stats { } // SetForwarding enables or disables the packet forwarding between NICs. +// +// When forwarding becomes enabled, any host-only state on all NICs will be +// cleaned up. func (s *Stack) SetForwarding(enable bool) { // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. s.mu.Lock() + defer s.mu.Unlock() + + // If forwarding status didn't change, do nothing further. + if s.forwarding == enable { + return + } + s.forwarding = enable - s.mu.Unlock() + + // If this stack does not support IPv6, do nothing further. + if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok { + return + } + + if enable { + for _, nic := range s.nics { + nic.becomeIPv6Router() + } + } } // Forwarding returns if the packet forwarding between NICs is enabled. @@ -735,7 +798,7 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { +func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled bool) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -744,7 +807,7 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, return tcpip.ErrDuplicateNICID } - n := newNIC(s, id, name, ep, loopback) + n := newNIC(s, id, name, ep) s.nics[id] = n if enabled { @@ -756,32 +819,26 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, // CreateNIC creates a NIC with the provided id and link-layer endpoint. func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, "", ep, true, false) + return s.createNIC(id, "", ep, true) } // CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, // and a human-readable name. func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, true, false) -} - -// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer -// endpoint, and a human-readable name. -func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, true, true) + return s.createNIC(id, name, ep, true) } // CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, // but leave it disable. Stack.EnableNIC must be called before the link-layer // endpoint starts delivering packets to it. func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, "", ep, false, false) + return s.createNIC(id, "", ep, false) } // CreateDisabledNamedNIC is a combination of CreateNamedNIC and // CreateDisabledNIC. func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, false, false) + return s.createNIC(id, name, ep, false) } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -848,7 +905,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Up: true, // Netstack interfaces are always up. Running: nic.linkEP.IsAttached(), Promiscuous: nic.isPromiscuousMode(), - Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0, + Loopback: nic.isLoopback(), } nics[id] = NICInfo{ Name: nic.name, @@ -973,9 +1030,11 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { return nics } -// GetMainNICAddress returns the first primary address and prefix for the given -// NIC and protocol. Returns an error if the NIC doesn't exist and an empty -// value if the NIC doesn't have a primary address for the given protocol. +// GetMainNICAddress returns the first non-deprecated primary address and prefix +// for the given NIC and protocol. If no non-deprecated primary address exists, +// a deprecated primary address and prefix will be returned. Returns an error if +// the NIC doesn't exist and an empty value if the NIC doesn't have a primary +// address for the given protocol. func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -985,12 +1044,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID } - for _, a := range nic.PrimaryAddresses() { - if a.Protocol == protocol { - return a.AddressWithPrefix, nil - } - } - return tcpip.AddressWithPrefix{}, nil + return nic.primaryAddress(protocol), nil } func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { @@ -1012,7 +1066,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok { if ref := s.getRefEP(nic, localAddr, netProto); ref != nil { - return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback), nil + return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil } } } else { @@ -1028,7 +1082,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n remoteAddr = ref.ep.ID().LocalAddress } - r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback) + r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) if needRoute { r.NextHop = route.Gateway } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 8fc034ca1..33f20579f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -27,10 +27,12 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -122,7 +124,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -133,7 +135,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params b[1] = f.id.LocalAddress[0] b[2] = byte(params.Protocol) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) @@ -141,7 +143,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } @@ -149,11 +151,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -1894,55 +1896,67 @@ func TestNICForwarding(t *testing.T) { } // TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses -// (or lack there-of if disabled (default)). Note, DAD will be disabled in -// these tests. +// using the modified EUI-64 of the NIC's MAC address (or lack there-of if +// disabled (default)). Note, DAD will be disabled in these tests. func TestNICAutoGenAddr(t *testing.T) { tests := []struct { name string autoGen bool linkAddr tcpip.LinkAddress + iidOpts stack.OpaqueInterfaceIdentifierOptions shouldGen bool }{ { "Disabled", false, linkAddr1, + stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(nicID tcpip.NICID, _ string) string { + return fmt.Sprintf("nic%d", nicID) + }, + }, false, }, { "Enabled", true, linkAddr1, + stack.OpaqueInterfaceIdentifierOptions{}, true, }, { "Nil MAC", true, tcpip.LinkAddress([]byte(nil)), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Empty MAC", true, tcpip.LinkAddress(""), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Invalid MAC", true, tcpip.LinkAddress("\x01\x02\x03"), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Multicast MAC", true, tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Unspecified MAC", true, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, } @@ -1951,13 +1965,12 @@ func TestNICAutoGenAddr(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + OpaqueIIDOpts: test.iidOpts, } if test.autoGen { - // Only set opts.AutoGenIPv6LinkLocal when - // test.autoGen is true because - // opts.AutoGenIPv6LinkLocal should be false by - // default. + // Only set opts.AutoGenIPv6LinkLocal when test.autoGen is true because + // opts.AutoGenIPv6LinkLocal should be false by default. opts.AutoGenIPv6LinkLocal = true } @@ -1973,9 +1986,130 @@ func TestNICAutoGenAddr(t *testing.T) { } if test.shouldGen { + // Should have auto-generated an address and resolved immediately (DAD + // is disabled). + if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } + } else { + // Should not have auto-generated an address. + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + } + }) + } +} + +// TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local +// addresses with opaque interface identifiers. Link Local addresses should +// always be generated with opaque IIDs if configured to use them, even if the +// NIC has an invalid MAC address. +func TestNICAutoGenAddrWithOpaque(t *testing.T) { + const nicID = 1 + + var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte + n, err := rand.Read(secretKey[:]) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n) + } + + tests := []struct { + name string + nicName string + autoGen bool + linkAddr tcpip.LinkAddress + secretKey []byte + }{ + { + name: "Disabled", + nicName: "nic1", + autoGen: false, + linkAddr: linkAddr1, + secretKey: secretKey[:], + }, + { + name: "Enabled", + nicName: "nic1", + autoGen: true, + linkAddr: linkAddr1, + secretKey: secretKey[:], + }, + // These are all cases where we would not have generated a + // link-local address if opaque IIDs were disabled. + { + name: "Nil MAC and empty nicName", + nicName: "", + autoGen: true, + linkAddr: tcpip.LinkAddress([]byte(nil)), + secretKey: secretKey[:1], + }, + { + name: "Empty MAC and empty nicName", + autoGen: true, + linkAddr: tcpip.LinkAddress(""), + secretKey: secretKey[:2], + }, + { + name: "Invalid MAC", + nicName: "test", + autoGen: true, + linkAddr: tcpip.LinkAddress("\x01\x02\x03"), + secretKey: secretKey[:3], + }, + { + name: "Multicast MAC", + nicName: "test2", + autoGen: true, + linkAddr: tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + secretKey: secretKey[:4], + }, + { + name: "Unspecified MAC and nil SecretKey", + nicName: "test3", + autoGen: true, + linkAddr: tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: test.secretKey, + }, + } + + if test.autoGen { + // Only set opts.AutoGenIPv6LinkLocal when + // test.autoGen is true because + // opts.AutoGenIPv6LinkLocal should be false by + // default. + opts.AutoGenIPv6LinkLocal = true + } + + e := channel.New(10, 1280, test.linkAddr) + s := stack.New(opts) + if err := s.CreateNamedNIC(nicID, test.nicName, e); err != nil { + t.Fatalf("CreateNamedNIC(%d, %q, _) = %s", nicID, test.nicName, err) + } + + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + + if test.autoGen { // Should have auto-generated an address and // resolved immediately (DAD is disabled). - if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddrWithOpaqueIID(test.nicName, 0, test.secretKey), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) } } else { @@ -1988,6 +2122,55 @@ func TestNICAutoGenAddr(t *testing.T) { } } +// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are +// not auto-generated for loopback NICs. +func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { + const nicID = 1 + const nicName = "nicName" + + tests := []struct { + name string + opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions + }{ + { + name: "IID From MAC", + opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{}, + }, + { + name: "Opaque IID", + opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, + } + + e := loopback.New() + s := stack.New(opts) + if err := s.CreateNamedNIC(nicID, nicName, e); err != nil { + t.Fatalf("CreateNamedNIC(%d, %q, _) = %s", nicID, nicName, err) + } + + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want) + } + }) + } +} + // TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 // link-local addresses will only be assigned after the DAD process resolves. func TestNICAutoGenAddrDoesDAD(t *testing.T) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index bd5eb89ca..f62fd729f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -308,7 +308,7 @@ type ControlMessages struct { // HasTimestamp indicates whether Timestamp is valid/set. HasTimestamp bool - // Timestamp is the time (in ns) that the last packed used to create + // Timestamp is the time (in ns) that the last packet used to create // the read data was received. Timestamp int64 @@ -317,6 +317,18 @@ type ControlMessages struct { // Inq is the number of bytes ready to be received. Inq int32 + + // HasTOS indicates whether Tos is valid/set. + HasTOS bool + + // TOS is the IPv4 type of service of the associated packet. + TOS int8 + + // HasTClass indicates whether Tclass is valid/set. + HasTClass bool + + // Tclass is the IPv6 traffic class of the associated packet. + TClass int32 } // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) @@ -564,6 +576,11 @@ type KeepaliveIntervalOption time.Duration // closed. type KeepaliveCountOption int +// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user +// specified timeout for a given TCP connection. +// See: RFC5482 for details. +type TCPUserTimeoutOption time.Duration + // CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get // the current congestion control algorithm. type CongestionControlOption string @@ -912,6 +929,14 @@ type TCPStats struct { // ESTABLISHED state or the CLOSE-WAIT state. EstablishedResets *StatCounter + // EstablishedClosed is the number of times established TCP connections + // made a transition to CLOSED state. + EstablishedClosed *StatCounter + + // EstablishedTimedout is the number of times an established connection + // was reset because of keep-alive time out. + EstablishedTimedout *StatCounter + // ListenOverflowSynDrop is the number of times the listen queue overflowed // and a SYN was dropped. ListenOverflowSynDrop *StatCounter diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index dd1728f9c..3b353d56c 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -28,6 +28,7 @@ go_library( "forwarder.go", "protocol.go", "rcv.go", + "rcv_state.go", "reno.go", "sack.go", "sack_scoreboard.go", @@ -52,6 +53,7 @@ go_library( "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", "//pkg/tcpip/iptables", + "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index f543a6105..5422ae80c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -242,6 +242,13 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() + // Now inherit any socket options that should be inherited from the + // listening endpoint. + // In case of Forwarder listenEP will be nil and hence this check. + if l.listenEP != nil { + l.listenEP.propagateInheritableOptions(n) + } + // Register new endpoint so that packets are routed to it. if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { n.Close() @@ -298,8 +305,6 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } ep.mu.Lock() - ep.stack.Stats().TCP.CurrentEstablished.Increment() - ep.state = StateEstablished ep.isConnectNotified = true ep.mu.Unlock() @@ -352,6 +357,14 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } } +// propagateInheritableOptions propagates any options set on the listening +// endpoint to the newly created endpoint. +func (e *endpoint) propagateInheritableOptions(n *endpoint) { + e.mu.Lock() + n.userTimeout = e.userTimeout + e.mu.Unlock() +} + // handleSynSegment is called in its own goroutine once the listening endpoint // receives a SYN segment. It is responsible for completing the handshake and // queueing the new endpoint for acceptance. @@ -546,6 +559,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. + // We do not use transitionToStateEstablishedLocked here as there is + // no handshake state available when doing a SYN cookie based accept. n.stack.Stats().TCP.CurrentEstablished.Increment() n.state = StateEstablished n.isConnectNotified = true diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 4206db8b6..cdd69f360 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -218,6 +218,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // acceptable if the ack field acknowledges the SYN. if s.flagIsSet(header.TCPFlagRst) { if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 { + // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK + // was acceptable then signal the user "error: connection reset", drop + // the segment, enter CLOSED state, delete TCB, and return." + h.ep.mu.Lock() + h.ep.workerCleanup = true + h.ep.mu.Unlock() + // Although the RFC above calls out ECONNRESET, Linux actually returns + // ECONNREFUSED here so we do as well. return tcpip.ErrConnectionRefused } return nil @@ -252,6 +260,11 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // and the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { h.state = handshakeCompleted + + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) return nil } @@ -352,6 +365,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + return nil } @@ -853,7 +870,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { } e.state = StateError e.HardError = err - if err != tcpip.ErrConnectionReset { + if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { // The exact sequence number to be used for the RST is the same as the // one used by Linux. We need to handle the case of window being shrunk // which can cause sndNxt to be outside the acceptable window on the @@ -880,6 +897,30 @@ func (e *endpoint) completeWorkerLocked() { } } +// transitionToStateEstablisedLocked transitions a given endpoint +// to an established state using the handshake parameters provided. +// It also initializes sender/receiver if required. +func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { + if e.snd == nil { + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + } + if e.rcv == nil { + rcvBufSize := seqnum.Size(e.receiveBufferSize()) + e.rcvListMu.Lock() + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + // Bootstrap the auto tuning algorithm. Starting at zero will + // result in a really large receive window after the first auto + // tuning adjustment. + e.rcvAutoParams.prevCopied = int(h.rcvWnd) + e.rcvListMu.Unlock() + } + h.ep.stack.Stats().TCP.CurrentEstablished.Increment() + e.state = StateEstablished +} + // transitionToStateCloseLocked ensures that the endpoint is // cleaned up from the transport demuxer, "before" moving to // StateClose. This will ensure that no packet will be @@ -891,6 +932,7 @@ func (e *endpoint) transitionToStateCloseLocked() { } e.cleanupLocked() e.state = StateClose + e.stack.Stats().TCP.EstablishedClosed.Increment() } // tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed @@ -953,20 +995,6 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { func (e *endpoint) handleSegments() *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { - e.mu.RLock() - state := e.state - e.mu.RUnlock() - if state == StateClose { - // When we get into StateClose while processing from the queue, - // return immediately and let the protocolMainloop handle it. - // - // We can reach StateClose only while processing a previous segment - // or a notification from the protocolMainLoop (caller goroutine). - // This means that with this return, the segment dequeue below can - // never occur on a closed endpoint. - return nil - } - s := e.segmentQueue.dequeue() if s == nil { checkRequeue = false @@ -1024,6 +1052,24 @@ func (e *endpoint) handleSegments() *tcpip.Error { s.decRef() continue } + + // Now check if the received segment has caused us to transition + // to a CLOSED state, if yes then terminate processing and do + // not invoke the sender. + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + s.decRef() + return nil + } e.snd.handleRcvdSegment(s) } s.decRef() @@ -1049,14 +1095,27 @@ func (e *endpoint) handleSegments() *tcpip.Error { // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { + e.mu.RLock() + userTimeout := e.userTimeout + e.mu.RUnlock() + e.keepalive.Lock() if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { e.keepalive.Unlock() return nil } + // If a userTimeout is set then abort the connection if it is + // exceeded. + if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { + e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() + return tcpip.ErrTimeout + } + if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } @@ -1073,7 +1132,6 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { // whether it is enabled for this endpoint. func (e *endpoint) resetKeepaliveTimer(receivedData bool) { e.keepalive.Lock() - defer e.keepalive.Unlock() if receivedData { e.keepalive.unacked = 0 } @@ -1081,6 +1139,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { // data to send. if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { e.keepalive.timer.disable() + e.keepalive.Unlock() return } if e.keepalive.unacked > 0 { @@ -1088,6 +1147,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } else { e.keepalive.timer.enable(e.keepalive.idle) } + e.keepalive.Unlock() } // disableKeepaliveTimer stops the keepalive timer. @@ -1142,8 +1202,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.mu.Lock() - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateError e.HardError = err @@ -1152,25 +1210,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } - - // Transfer handshake state to TCP connection. We disable - // receive window scaling if the peer doesn't support it - // (indicated by a negative send window scale). - e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - - rcvBufSize := seqnum.Size(e.receiveBufferSize()) - e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) - // boot strap the auto tuning algorithm. Starting at zero will - // result in a large step function on the first proper causing - // the window to just go to a really large value after the first - // RTT itself. - e.rcvAutoParams.prevCopied = initialRcvWnd - e.rcvListMu.Unlock() - e.stack.Stats().TCP.CurrentEstablished.Increment() - e.mu.Lock() - e.state = StateEstablished - e.mu.Unlock() } e.keepalive.timer.init(&e.keepalive.waker) @@ -1221,6 +1260,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { w: &e.snd.resendWaker, f: func() *tcpip.Error { if !e.snd.retransmitTimerExpired() { + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } return nil @@ -1371,7 +1411,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() if e.state != StateError { - e.stack.Stats().TCP.EstablishedResets.Increment() e.stack.Stats().TCP.CurrentEstablished.Decrement() e.transitionToStateCloseLocked() } @@ -1388,6 +1427,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { if s == nil { break } + e.tryDeliverSegmentFromClosedEndpoint(s) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 9d4a87e30..fe629aa40 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tmutex" @@ -340,9 +341,11 @@ type endpoint struct { // TCP should never broadcast but Linux nevertheless supports enabling/ // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // Values used to reserve a port or register a transport endpoint // (which ever happens first). boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -472,6 +475,12 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // userTimeout if non-zero specifies a user specified timeout for + // a connection w/ pending data to send. A connection that has pending + // unacked data will be forcibily aborted if the timeout is reached + // without any data being acked. + userTimeout time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing @@ -737,9 +746,10 @@ func (e *endpoint) Close() { e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } // Mark endpoint as closed. @@ -800,10 +810,11 @@ func (e *endpoint) cleanupLocked() { } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false } e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} e.route.Release() e.stack.CompleteTransportEndpointCleanup(e) @@ -1329,6 +1340,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.notifyProtocolGoroutine(notifyKeepaliveChanged) return nil + case tcpip.TCPUserTimeoutOption: + e.mu.Lock() + e.userTimeout = time.Duration(v) + e.mu.Unlock() + return nil + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 @@ -1587,6 +1604,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.keepalive.Unlock() return nil + case *tcpip.TCPUserTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPUserTimeoutOption(e.userTimeout) + e.mu.Unlock() + return nil + case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 @@ -1775,7 +1798,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } // reusePort is false below because connect cannot reuse a port even if // reusePort was set. - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) { + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) { return false, nil } @@ -1802,7 +1825,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // before Connect: in such a case we don't want to hold on to // reservations anymore. if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false } @@ -1951,6 +1974,15 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { return nil } + if e.state == StateInitial { + // The listen is called on an unbound socket, the socket is + // automatically bound to a random free port with the local + // address set to INADDR_ANY. + if err := e.bindLocked(tcpip.FullAddress{}); err != nil { + return err + } + } + // Endpoint must be bound before it can transition to listen mode. if e.state != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() @@ -2010,6 +2042,10 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() + return e.bindLocked(addr) +} + +func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. @@ -2034,28 +2070,33 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice) + flags := ports.Flags{ + LoadBalanced: e.reusePort, + } + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice) if err != nil { return err } e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = flags e.isPortReserved = true e.effectiveNetProtos = netProtos e.ID.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func(bindToDevice tcpip.NICID) { + defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice) e.isPortReserved = false e.effectiveNetProtos = nil e.ID.LocalPort = 0 e.ID.LocalAddress = "" e.boundNICID = 0 e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } - }(e.boundBindToDevice) + }(e.boundPortFlags, e.boundBindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 89b965c23..bc718064c 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -162,13 +162,26 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo func replyWithReset(s *segment) { // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) + ack := seqnum.Value(0) + flags := byte(header.TCPFlagRst) + // As per RFC 793 page 35 (Reset Generation) + // 1. If the connection does not exist (CLOSED) then a reset is sent + // in response to any incoming segment except another reset. In + // particular, SYNs addressed to a non-existent connection are rejected + // by this means. + + // If the incoming segment has an ACK field, the reset takes its + // sequence number from the ACK field of the segment, otherwise the + // reset has sequence number zero and the ACK field is set to the sum + // of the sequence number and segment length of the incoming segment. + // The connection remains in the CLOSED state. if s.flagIsSet(header.TCPFlagAck) { seq = s.ackNumber + } else { + flags |= header.TCPFlagAck + ack = s.sequenceNumber.Add(s.logicalLen()) } - - ack := s.sequenceNumber.Add(s.logicalLen()) - - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } // SetOption implements TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 857dc445f..0a5534959 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -50,16 +50,20 @@ type receiver struct { pendingRcvdSegments segmentHeap pendingBufUsed seqnum.Size pendingBufSize seqnum.Size + + // Time when the last ack was received. + lastRcvdAckTime time.Time `state:".(unixTime)"` } func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { return &receiver{ - ep: ep, - rcvNxt: irs + 1, - rcvAcc: irs.Add(rcvWnd + 1), - rcvWnd: rcvWnd, - rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, + ep: ep, + rcvNxt: irs + 1, + rcvAcc: irs.Add(rcvWnd + 1), + rcvWnd: rcvWnd, + rcvWndScale: rcvWndScale, + pendingBufSize: pendingBufSize, + lastRcvdAckTime: time.Now(), } } @@ -205,7 +209,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { r.ep.mu.Lock() switch r.ep.state { case StateFinWait1: @@ -360,6 +364,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { return true, nil } + // Store the time of the last ack. + r.lastRcvdAckTime = time.Now() + // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go new file mode 100644 index 000000000..2bf21a2e7 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rcv_state.go @@ -0,0 +1,29 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveLastRcvdAckTime is invoked by stateify. +func (r *receiver) saveLastRcvdAckTime() unixTime { + return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()} +} + +// loadLastRcvdAckTime is invoked by stateify. +func (r *receiver) loadLastRcvdAckTime(unix unixTime) { + r.lastRcvdAckTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index d3f7c9125..8a947dc66 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -28,8 +28,11 @@ import ( ) const ( - // minRTO is the minimum allowed value for the retransmit timeout. - minRTO = 200 * time.Millisecond + // MinRTO is the minimum allowed value for the retransmit timeout. + MinRTO = 200 * time.Millisecond + + // MaxRTO is the maximum allowed value for the retransmit timeout. + MaxRTO = 120 * time.Second // InitialCwnd is the initial congestion window. InitialCwnd = 10 @@ -134,6 +137,10 @@ type sender struct { // rttMeasureTime is the time when the rttMeasureSeqNum was sent. rttMeasureTime time.Time `state:".(unixTime)"` + // firstRetransmittedSegXmitTime is the original transmit time of + // the first segment that was retransmitted due to RTO expiration. + firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` + closed bool writeNext *segment writeList segmentList @@ -392,8 +399,8 @@ func (s *sender) updateRTO(rtt time.Duration) { s.rto = s.rtt.srtt + 4*s.rtt.rttvar s.rtt.Unlock() - if s.rto < minRTO { - s.rto = minRTO + if s.rto < MinRTO { + s.rto = MinRTO } } @@ -438,8 +445,30 @@ func (s *sender) retransmitTimerExpired() bool { s.ep.stack.Stats().TCP.Timeouts.Increment() s.ep.stats.SendErrors.Timeouts.Increment() - // Give up if we've waited more than a minute since the last resend. - if s.rto >= 60*time.Second { + // Give up if we've waited more than a minute since the last resend or + // if a user time out is set and we have exceeded the user specified + // timeout since the first retransmission. + s.ep.mu.RLock() + uto := s.ep.userTimeout + s.ep.mu.RUnlock() + + if s.firstRetransmittedSegXmitTime.IsZero() { + // We store the original xmitTime of the segment that we are + // about to retransmit as the retransmission time. This is + // required as by the time the retransmitTimer has expired the + // segment has already been sent and unacked for the RTO at the + // time the segment was sent. + s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime + } + + elapsed := time.Since(s.firstRetransmittedSegXmitTime) + remaining := MaxRTO + if uto != 0 { + // Cap to the user specified timeout if one is specified. + remaining = uto - elapsed + } + + if remaining <= 0 || s.rto >= MaxRTO { return false } @@ -447,6 +476,11 @@ func (s *sender) retransmitTimerExpired() bool { // below. s.rto *= 2 + // Cap RTO to remaining time. + if s.rto > remaining { + s.rto = remaining + } + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. // // Retransmit timeouts: @@ -674,7 +708,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se default: s.ep.state = StateFinWait1 } - s.ep.stack.Stats().TCP.CurrentEstablished.Decrement() s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. @@ -1169,6 +1202,8 @@ func (s *sender) handleRcvdSegment(seg *segment) { // RFC 6298 Rule 5.3 if s.sndUna == s.sndNxt { s.outstanding = 0 + // Reset firstRetransmittedSegXmitTime to the zero value. + s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() } } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go index 12eff8afc..8b20c3455 100644 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ b/pkg/tcpip/transport/tcp/snd_state.go @@ -48,3 +48,13 @@ func (s *sender) loadRttMeasureTime(unix unixTime) { func (s *sender) afterLoad() { s.resendTimer.init(&s.resendWaker) } + +// saveFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime { + return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()} +} + +// loadFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) { + s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index c4b45aa6f..e8fe4dab5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -75,6 +75,20 @@ func TestGiveUpConnect(t *testing.T) { if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted) } + + // Call Connect again to retreive the handshake failure status + // and stats updates. + if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { + t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted) + } + + if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestConnectIncrementActiveConnection(t *testing.T) { @@ -309,8 +323,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) } func TestTCPResetsReceivedIncrement(t *testing.T) { @@ -446,18 +460,17 @@ func TestConnectResetAfterClose(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), ), ) break } } -// TestClosingWithEnqueuedSegments tests handling of -// still enqueued segments when the endpoint transitions -// to StateClose. The in-flight segments would be re-enqueued -// to a any listening endpoint. +// TestClosingWithEnqueuedSegments tests handling of still enqueued segments +// when the endpoint transitions to StateClose. The in-flight segments would be +// re-enqueued to a any listening endpoint. func TestClosingWithEnqueuedSegments(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -541,21 +554,29 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { ep.(interface{ ResumeWork() }).ResumeWork() // Wait for the protocolMainLoop to resume and update state. - time.Sleep(1 * time.Millisecond) + time.Sleep(10 * time.Millisecond) // Expect the endpoint to be closed. if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } + // Check if the endpoint was moved to CLOSED and netstack a reset in // response to the ACK packet that we sent after last-ACK. checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(793), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), ), ) } @@ -792,6 +813,82 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) { checker.SeqNum(200))) } +// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, +// peers can send data and expect a response within a reasonable ammount of time +// without calling Accept on the listening endpoint first. +// +// This test uses IPv4. +func TestTCPAckBeforeAcceptV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + // Send data before accepting the connection. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, +// peers can send data and expect a response within a reasonable ammount of time +// without calling Accept on the listening endpoint first. +// +// This test uses IPv6. +func TestTCPAckBeforeAcceptV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */) + + // Send data before accepting the connection. + c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + func TestSendRstOnListenerRxAckV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -816,7 +913,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.TCPFlags(header.TCPFlagRst), checker.SeqNum(200))) } @@ -844,7 +941,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.TCPFlags(header.TCPFlagRst), checker.SeqNum(200))) } @@ -1043,6 +1140,71 @@ func TestConnectBindToDevice(t *testing.T) { } } +func TestRstOnSynSent(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create an endpoint, don't handshake because we want to interfere with the + // handshake process. + c.Create(-1) + + // Start connection attempt. + waitEntry, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { + t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, tcpip.ErrConnectStarted) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + // Ensure that we've reached SynSent state + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + // Send a packet with a proper ACK and a RST flag to cause the socket + // to Error and close out + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagRst | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for packet to arrive") + } + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused { + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrConnectionRefused) + } + + // Due to the RST the endpoint should be in an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } +} + func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -2618,6 +2780,13 @@ loop: if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } + + if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { + t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestSendOnResetConnection(t *testing.T) { @@ -4186,8 +4355,9 @@ func TestKeepalive(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + const keepAliveInterval = 10 * time.Millisecond c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) @@ -4277,19 +4447,43 @@ func TestKeepalive(t *testing.T) { ) } + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + 5*time.Millisecond) + // The connection should be terminated after 5 unacked keepalives. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(next)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), ), ) + if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got) + } + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { @@ -4303,7 +4497,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki RcvWnd: 30000, }) - // Receive the SYN-ACK reply.w + // Receive the SYN-ACK reply. b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) iss = seqnum.Value(tcp.SequenceNumber()) @@ -4336,6 +4530,50 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki return irs, iss } +func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + // Send a SYN request. + irs = seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetV6Packet() + tcp := header.TCP(header.IPv6(b).Payload()) + iss = seqnum.Value(tcp.SequenceNumber()) + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.AckNum(uint32(irs) + 1), + } + + if synCookieInUse { + // When cookies are in use window scaling is disabled. + tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ + WS: -1, + MSS: c.MSSWithoutOptionsV6(), + })) + } + + checker.IPv6(t, b, checker.TCP(tcpCheckers...)) + + // Send ACK. + c.SendV6Packet(nil, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + return irs, iss +} + // TestListenBacklogFull tests that netstack does not complete handshakes if the // listen backlog for the endpoint is full. func TestListenBacklogFull(t *testing.T) { @@ -5512,6 +5750,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: iss, + RcvWnd: 30000, }) // Receive the SYN-ACK reply. @@ -5630,6 +5869,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: iss, + RcvWnd: 30000, }) // Receive the SYN-ACK reply. @@ -5736,6 +5976,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: iss, + RcvWnd: 30000, }) // Receive the SYN-ACK reply. @@ -5809,6 +6050,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: iss, + RcvWnd: 30000, }) c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) @@ -5821,6 +6063,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: iss, + RcvWnd: 30000, }) // Receive the SYN-ACK reply. @@ -5867,6 +6110,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) } + want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 + wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { @@ -5887,6 +6132,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: iss, + RcvWnd: 30000, }) // Receive the SYN-ACK reply. @@ -5992,6 +6238,326 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(uint32(ackHeaders.SeqNum)), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) + + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } +} + +func TestTCPCloseWithData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + RcvWnd: 30000, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now trigger a passive close by sending a FIN. + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + RcvWnd: 30000, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now write a few bytes and then close the endpoint. + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b = c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } + + c.EP.Close() + // Check the FIN. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))), + checker.AckNum(uint32(iss+2)), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + // First send a partial ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now send a full ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now ACK the FIN. + ackHeaders.AckNum++ + c.SendPacket(nil, ackHeaders) + + // Now send an ACK and we should get a RST back as the endpoint should + // be in CLOSED state. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Check the RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) +} + +func TestTCPUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + + userTimeout := 50 * time.Millisecond + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Send some data and wait before ACKing it. + view := buffer.NewView(3) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Wait for a little over the minimum retransmit timeout of 200ms for + // the retransmitTimer to fire and close the connection. + time.Sleep(tcp.MinRTO + 10*time.Millisecond) + + // No packet should be received as the connection should be silently + // closed due to timeout. + c.CheckNoPacket("unexpected packet received after userTimeout has expired") + + next += uint32(len(view)) + + // The connection should be terminated after userTimeout has expired. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(next)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + } + + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + } +} + +func TestKeepaliveWithUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + + const keepAliveInterval = 10 * time.Millisecond + c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) + c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10)) + c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + + // Set userTimeout to be the duration for 3 keepalive probes. + userTimeout := 30 * time.Millisecond + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Check that the connection is still alive. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + // Now receive 2 keepalives, but don't ACK them. The connection should + // be reset when the 3rd one should be sent due to userTimeout being + // 30ms and each keepalive probe should be sent 10ms apart as set above after + // the connection has been idle for 10ms. + for i := 0; i < 2; i++ { + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)), + checker.AckNum(uint32(790)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + 5*time.Millisecond) + + // The connection should be terminated after 30ms. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(c.IRS + 1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + } + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + } } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 04fdaaed1..b0a376eba 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -231,6 +231,7 @@ func (c *Context) CheckNoPacket(errMsg string) { // addresses. It will fail with an error if no packet is received for // 2 seconds. func (c *Context) GetPacket() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { @@ -259,6 +260,7 @@ func (c *Context) GetPacket() []byte { // and destination address. If no packet is available it will return // nil immediately. func (c *Context) GetPacketNonBlocking() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { @@ -483,6 +485,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) { // GetV6Packet reads a single packet from the link layer endpoint of the context // and asserts that it is an IPv6 Packet with the expected src/dest addresses. func (c *Context) GetV6Packet() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv6.ProtocolNumber { @@ -1089,3 +1092,9 @@ func (c *Context) SetGSOEnabled(enable bool) { func (c *Context) MSSWithoutOptions() uint16 { return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) } + +// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no +// options are in use for IPv6 packets. +func (c *Context) MSSWithoutOptionsV6() uint16 { + return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize) +} diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 8d4c3808f..97e4d5825 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -34,6 +34,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/iptables", + "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/waiter", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 24cb88c13..1ac4705af 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -107,6 +108,7 @@ type endpoint struct { // Values used to reserve a port or register a transport endpoint. // (which ever happens first). boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags // sendTOS represents IPv4 TOS or IPv6 TrafficClass, // applied while sending packets. Defaults to 0 as on Linux. @@ -180,8 +182,9 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } for _, mem := range e.multicastMemberships { @@ -895,7 +898,8 @@ func (e *endpoint) Disconnect() *tcpip.Error { } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.boundPortFlags = ports.Flags{} } e.state = StateInitial } @@ -1042,16 +1046,23 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) + flags := ports.Flags{ + LoadBalanced: e.reusePort, + // FIXME(b/129164367): Support SO_REUSEADDR. + MostRecent: false, + } + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice) if err != nil { return id, e.bindToDevice, err } + e.boundPortFlags = flags id.LocalPort = port } err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice) + e.boundPortFlags = ports.Flags{} } return id, e.bindToDevice, err } @@ -1134,9 +1145,14 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() + addr := e.ID.LocalAddress + if e.state == StateConnected { + addr = e.route.LocalAddress + } + return tcpip.FullAddress{ NIC: e.RegisterNICID, - Addr: e.ID.LocalAddress, + Addr: addr, Port: e.ID.LocalPort, }, nil } |