diff options
Diffstat (limited to 'pkg/tcpip')
30 files changed, 2556 insertions, 792 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 65d4d0cd8..e07ebd153 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -10,6 +10,7 @@ go_library( "packet_buffer_state.go", "tcpip.go", "time_unsafe.go", + "timer.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip", visibility = ["//visibility:public"], @@ -26,3 +27,10 @@ go_test( srcs = ["tcpip_test.go"], embed = [":tcpip"], ) + +go_test( + name = "timer_test", + size = "small", + srcs = ["timer_test.go"], + deps = [":tcpip"], +) diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index f1d837196..f2061c778 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -44,6 +44,7 @@ go_test( ], deps = [ ":header", + "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", "@com_github_google_go-cmp//cmp:go_default_library", diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index fc671e439..135a60b12 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -15,6 +15,7 @@ package header import ( + "crypto/sha256" "encoding/binary" "strings" @@ -102,6 +103,11 @@ const ( // bytes including and after the IIDOffsetInIPv6Address-th byte are // for the IID. IIDOffsetInIPv6Address = 8 + + // OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes + // for the secret key used to generate an opaque interface identifier as + // outlined by RFC 7217. + OpaqueIIDSecretKeyMinBytes = 16 ) // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the @@ -326,3 +332,42 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { } return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } + +// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier +// (IID) to buf as outlined by RFC 7217 and returns the extended buffer. +// +// The opaque IID is generated from the cryptographic hash of the concatenation +// of the prefix, NIC's name, DAD counter (DAD retry counter) and the secret +// key. The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes bytes and +// MUST be generated to a pseudo-random number. See RFC 4086 for randomness +// requirements for security. +// +// If buf has enough capacity for the IID (IIDSize bytes), a new underlying +// array for the buffer will not be allocated. +func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte { + // As per RFC 7217 section 5, the opaque identifier can be generated as a + // cryptographic hash of the concatenation of each of the function parameters. + // Note, we omit the optional Network_ID field. + h := sha256.New() + // h.Write never returns an error. + h.Write([]byte(prefix.ID()[:IIDOffsetInIPv6Address])) + h.Write([]byte(nicName)) + h.Write([]byte{dadCounter}) + h.Write(secretKey) + + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + return append(buf, sum[:IIDSize]...) +} + +// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with +// an opaque IID. +func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address { + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, + } + + return tcpip.Address(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey)) +} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index 42c5c6fc1..1994003ed 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -15,9 +15,12 @@ package header_test import ( + "bytes" + "crypto/sha256" "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -43,3 +46,163 @@ func TestLinkLocalAddr(t *testing.T) { t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) } } + +func TestAppendOpaqueInterfaceIdentifier(t *testing.T) { + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte + if n, err := rand.Read(secretKeyBuf[:]); err != nil { + t.Fatalf("rand.Read(_): %s", err) + } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) + } + + tests := []struct { + name string + prefix tcpip.Subnet + nicName string + dadCounter uint8 + secretKey []byte + }{ + { + name: "SecretKey of minimum size", + prefix: header.IPv6LinkLocalPrefix.Subnet(), + nicName: "eth0", + dadCounter: 0, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], + }, + { + name: "SecretKey of less than minimum size", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "eth10", + dadCounter: 1, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], + }, + { + name: "SecretKey of more than minimum size", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "eth11", + dadCounter: 2, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], + }, + { + name: "Nil SecretKey and empty nicName", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "", + dadCounter: 3, + secretKey: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + h := sha256.New() + h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address])) + h.Write([]byte(test.nicName)) + h.Write([]byte{test.dadCounter}) + if k := test.secretKey; k != nil { + h.Write(k) + } + var hashSum [sha256.Size]byte + h.Sum(hashSum[:0]) + want := hashSum[:header.IIDSize] + + // Passing a nil buffer should result in a new buffer returned with the + // IID. + if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { + t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) + } + + // Passing a buffer with sufficient capacity for the IID should populate + // the buffer provided. + var iidBuf [header.IIDSize]byte + if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { + t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) + } + if got := iidBuf[:]; !bytes.Equal(got, want) { + t.Errorf("got iidBuf = %x, want = %x", got, want) + } + }) + } +} + +func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte + if n, err := rand.Read(secretKeyBuf[:]); err != nil { + t.Fatalf("rand.Read(_): %s", err) + } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) + } + + prefix := header.IPv6LinkLocalPrefix.Subnet() + + tests := []struct { + name string + prefix tcpip.Subnet + nicName string + dadCounter uint8 + secretKey []byte + }{ + { + name: "SecretKey of minimum size", + nicName: "eth0", + dadCounter: 0, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], + }, + { + name: "SecretKey of less than minimum size", + nicName: "eth10", + dadCounter: 1, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], + }, + { + name: "SecretKey of more than minimum size", + nicName: "eth11", + dadCounter: 2, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], + }, + { + name: "Nil SecretKey and empty nicName", + nicName: "", + dadCounter: 3, + secretKey: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + addrBytes := [header.IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, + } + + want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier( + addrBytes[:header.IIDOffsetInIPv6Address], + prefix, + test.nicName, + test.dadCounter, + test.secretKey, + )) + + if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want { + t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want) + } + }) + } +} diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index da8482509..42cacb8a6 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -79,16 +79,16 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4144a7837..f1bc33adf 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -239,7 +239,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -480,7 +480,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e645cf62c..4ee3d5b45 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -238,11 +238,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) @@ -256,7 +256,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw loopedR.Release() } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { @@ -270,11 +270,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { - if loop&stack.PacketLoop != 0 { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { + if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return len(pkts), nil } @@ -289,7 +289,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. ip := header.IPv4(pkt.Data.First()) @@ -324,10 +324,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLo ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { e.HandlePacket(r, pkt.Clone()) } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index e13f1fabf..58c3c79b9 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -112,11 +112,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) @@ -130,7 +130,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw loopedR.Release() } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } @@ -139,11 +139,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { - if loop&stack.PacketLoop != 0 { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { + if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return len(pkts), nil } @@ -161,7 +161,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 69077669a..826fca4de 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -52,23 +52,21 @@ go_test( name = "stack_x_test", size = "small", srcs = [ - "ndp_test.go", "stack_test.go", "transport_demuxer_test.go", "transport_test.go", ], deps = [ ":stack", + "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go-cmp//cmp:go_default_library", @@ -85,3 +83,23 @@ go_test( "//pkg/tcpip", ], ) + +go_test( + name = "ndp_test", + size = "small", + srcs = ["ndp_test.go"], + deps = [ + ":stack", + "//pkg/rand", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go-cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index d9ab59336..a9dd322db 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -115,6 +115,30 @@ var ( MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour ) +// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an +// NDP Router Advertisement informed the Stack about. +type DHCPv6ConfigurationFromNDPRA int + +const ( + // DHCPv6NoConfiguration indicates that no configurations are available via + // DHCPv6. + DHCPv6NoConfiguration DHCPv6ConfigurationFromNDPRA = iota + + // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6. + // + // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6 + // will return all available configuration information. + DHCPv6ManagedAddress + + // DHCPv6OtherConfigurations indicates that other configuration information is + // available via DHCPv6. + // + // Other configurations are configurations other than addresses. Examples of + // other configurations are recursive DNS server list, DNS search lists and + // default gateway. + DHCPv6OtherConfigurations +) + // NDPDispatcher is the interface integrators of netstack must implement to // receive and handle NDP related events. type NDPDispatcher interface { @@ -169,6 +193,15 @@ type NDPDispatcher interface { // call functions on the stack itself. OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + // OnAutoGenAddressDeprecated will be called when an auto-generated + // address (as part of SLAAC) has been deprecated, but is still + // considered valid. Note, if an address is invalidated at the same + // time it is deprecated, the deprecation event MAY be omitted. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) + // OnAutoGenAddressInvalidated will be called when an auto-generated // address (as part of SLAAC) has been invalidated. // @@ -185,7 +218,20 @@ type NDPDispatcher interface { // already known DNS servers. If called with known DNS servers, their // valid lifetimes must be refreshed to lifetime (it may be increased, // decreased, or completely invalidated when lifetime = 0). + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) + + // OnDHCPv6Configuration will be called with an updated configuration that is + // available via DHCPv6 for a specified NIC. + // + // NDPDispatcher assumes that the initial configuration available by DHCPv6 is + // DHCPv6NoConfiguration. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) } // NDPConfigurations is the NDP configurations for the netstack. @@ -272,6 +318,9 @@ type ndpState struct { // The addresses generated by SLAAC. autoGenAddresses map[tcpip.Address]autoGenAddressState + + // The last learned DHCPv6 configuration from an NDP RA. + dhcpv6Configuration DHCPv6ConfigurationFromNDPRA } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -290,71 +339,27 @@ type dadState struct { // defaultRouterState holds data associated with a default router discovered by // a Router Advertisement (RA). type defaultRouterState struct { - invalidationTimer *time.Timer - - // Used to inform the timer not to invalidate the default router (R) in - // a race condition (T1 is a goroutine that handles an RA from R and T2 - // is the goroutine that handles R's invalidation timer firing): - // T1: Receive a new RA from R - // T1: Obtain the NIC's lock before processing the RA - // T2: R's invalidation timer fires, and gets blocked on obtaining the - // NIC's lock - // T1: Refreshes/extends R's lifetime & releases NIC's lock - // T2: Obtains NIC's lock & invalidates R immediately - // - // To resolve this, T1 will check to see if the timer already fired, and - // inform the timer using doNotInvalidate to not invalidate R, so that - // once T2 obtains the lock, it will see that it is set to true and do - // nothing further. - doNotInvalidate *bool + invalidationTimer tcpip.CancellableTimer } // onLinkPrefixState holds data associated with an on-link prefix discovered by // a Router Advertisement's Prefix Information option (PI) when the NDP // configurations was configured to do so. type onLinkPrefixState struct { - invalidationTimer *time.Timer - - // Used to signal the timer not to invalidate the on-link prefix (P) in - // a race condition (T1 is a goroutine that handles a PI for P and T2 - // is the goroutine that handles P's invalidation timer firing): - // T1: Receive a new PI for P - // T1: Obtain the NIC's lock before processing the PI - // T2: P's invalidation timer fires, and gets blocked on obtaining the - // NIC's lock - // T1: Refreshes/extends P's lifetime & releases NIC's lock - // T2: Obtains NIC's lock & invalidates P immediately - // - // To resolve this, T1 will check to see if the timer already fired, and - // inform the timer using doNotInvalidate to not invalidate P, so that - // once T2 obtains the lock, it will see that it is set to true and do - // nothing further. - doNotInvalidate *bool + invalidationTimer tcpip.CancellableTimer } // autoGenAddressState holds data associated with an address generated via // SLAAC. type autoGenAddressState struct { - invalidationTimer *time.Timer - - // Used to signal the timer not to invalidate the SLAAC address (A) in - // a race condition (T1 is a goroutine that handles a PI for A and T2 - // is the goroutine that handles A's invalidation timer firing): - // T1: Receive a new PI for A - // T1: Obtain the NIC's lock before processing the PI - // T2: A's invalidation timer fires, and gets blocked on obtaining the - // NIC's lock - // T1: Refreshes/extends A's lifetime & releases NIC's lock - // T2: Obtains NIC's lock & invalidates A immediately - // - // To resolve this, T1 will check to see if the timer already fired, and - // inform the timer using doNotInvalidate to not invalidate A, so that - // once T2 obtains the lock, it will see that it is set to true and do - // nothing further. - doNotInvalidate *bool - - // Nonzero only when the address is not valid forever (invalidationTimer - // is not nil). + // A reference to the referencedNetworkEndpoint that this autoGenAddressState + // is holding state for. + ref *referencedNetworkEndpoint + + deprecationTimer tcpip.CancellableTimer + invalidationTimer tcpip.CancellableTimer + + // Nonzero only when the address is not valid forever. validUntil time.Time } @@ -556,7 +561,7 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { // handleRA handles a Router Advertisement message that arrived on the NIC // this ndp is for. Does nothing if the NIC is configured to not handle RAs. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // Is the NIC configured to handle RAs at all? // @@ -568,6 +573,28 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { return } + // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we + // only inform the dispatcher on configuration changes. We do nothing else + // with the information. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + var configuration DHCPv6ConfigurationFromNDPRA + switch { + case ra.ManagedAddrConfFlag(): + configuration = DHCPv6ManagedAddress + + case ra.OtherConfFlag(): + configuration = DHCPv6OtherConfigurations + + default: + configuration = DHCPv6NoConfiguration + } + + if ndp.dhcpv6Configuration != configuration { + ndp.dhcpv6Configuration = configuration + ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration) + } + } + // Is the NIC configured to discover default routers? if ndp.configs.DiscoverDefaultRouters { rtr, ok := ndp.defaultRouters[ip] @@ -585,27 +612,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { case ok && rl != 0: // This is an already discovered default router. Update // the invalidation timer. - timer := rtr.invalidationTimer - - // We should ALWAYS have an invalidation timer for a - // discovered router. - if timer == nil { - panic("ndphandlera: RA invalidation timer should not be nil") - } - - if !timer.Stop() { - // If we reach this point, then we know the - // timer fired after we already took the NIC - // lock. Inform the timer not to invalidate the - // router when it obtains the lock as we just - // got a new RA that refreshes its lifetime to a - // non-zero value. See - // defaultRouterState.doNotInvalidate for more - // details. - *rtr.doNotInvalidate = true - } - - timer.Reset(rl) + rtr.invalidationTimer.StopLocked() + rtr.invalidationTimer.Reset(rl) + ndp.defaultRouters[ip] = rtr case ok && rl == 0: // We know about the router but it is no longer to be @@ -672,10 +681,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { return } - rtr.invalidationTimer.Stop() - rtr.invalidationTimer = nil - *rtr.doNotInvalidate = true - rtr.doNotInvalidate = nil + rtr.invalidationTimer.StopLocked() delete(ndp.defaultRouters, ip) @@ -704,27 +710,15 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { return } - // Used to signal the timer not to invalidate the default router (R) in - // a race condition. See defaultRouterState.doNotInvalidate for more - // details. - var doNotInvalidate bool - - ndp.defaultRouters[ip] = defaultRouterState{ - invalidationTimer: time.AfterFunc(rl, func() { - ndp.nic.stack.mu.Lock() - defer ndp.nic.stack.mu.Unlock() - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if doNotInvalidate { - doNotInvalidate = false - return - } - + state := defaultRouterState{ + invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { ndp.invalidateDefaultRouter(ip) }), - doNotInvalidate: &doNotInvalidate, } + + state.invalidationTimer.Reset(rl) + + ndp.defaultRouters[ip] = state } // rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 @@ -746,21 +740,17 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) return } - // Used to signal the timer not to invalidate the on-link prefix (P) in - // a race condition. See onLinkPrefixState.doNotInvalidate for more - // details. - var doNotInvalidate bool - var timer *time.Timer + state := onLinkPrefixState{ + invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + ndp.invalidateOnLinkPrefix(prefix) + }), + } - // Only create a timer if the lifetime is not infinite. if l < header.NDPInfiniteLifetime { - timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate) + state.invalidationTimer.Reset(l) } - ndp.onLinkPrefixes[prefix] = onLinkPrefixState{ - invalidationTimer: timer, - doNotInvalidate: &doNotInvalidate, - } + ndp.onLinkPrefixes[prefix] = state } // invalidateOnLinkPrefix invalidates a discovered on-link prefix. @@ -775,13 +765,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { return } - if s.invalidationTimer != nil { - s.invalidationTimer.Stop() - s.invalidationTimer = nil - *s.doNotInvalidate = true - } - - s.doNotInvalidate = nil + s.invalidationTimer.StopLocked() delete(ndp.onLinkPrefixes, prefix) @@ -791,28 +775,6 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { } } -// prefixInvalidationCallback returns a new on-link prefix invalidation timer -// for prefix that fires after vl. -// -// doNotInvalidate is used to signal the timer when it fires at the same time -// that a prefix's valid lifetime gets refreshed. See -// onLinkPrefixState.doNotInvalidate for more details. -func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Duration, doNotInvalidate *bool) *time.Timer { - return time.AfterFunc(vl, func() { - ndp.nic.stack.mu.Lock() - defer ndp.nic.stack.mu.Unlock() - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if *doNotInvalidate { - *doNotInvalidate = false - return - } - - ndp.invalidateOnLinkPrefix(prefix) - }) -} - // handleOnLinkPrefixInformation handles a Prefix Information option with // its on-link flag set, as per RFC 4861 section 6.3.4. // @@ -852,42 +814,17 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // This is an already discovered on-link prefix with a // new non-zero valid lifetime. + // // Update the invalidation timer. - timer := prefixState.invalidationTimer - - if timer == nil && vl >= header.NDPInfiniteLifetime { - // Had infinite valid lifetime before and - // continues to have an invalid lifetime. Do - // nothing further. - return - } - if timer != nil && !timer.Stop() { - // If we reach this point, then we know the timer alread fired - // after we took the NIC lock. Inform the timer to not - // invalidate the prefix once it obtains the lock as we just - // got a new PI that refreshes its lifetime to a non-zero value. - // See onLinkPrefixState.doNotInvalidate for more details. - *prefixState.doNotInvalidate = true - } - - if vl >= header.NDPInfiniteLifetime { - // Prefix is now valid forever so we don't need - // an invalidation timer. - prefixState.invalidationTimer = nil - ndp.onLinkPrefixes[prefix] = prefixState - return - } + prefixState.invalidationTimer.StopLocked() - if timer != nil { - // We already have a timer so just reset it to - // expire after the new valid lifetime. - timer.Reset(vl) - return + if vl < header.NDPInfiniteLifetime { + // Prefix is valid for a finite lifetime, reset the timer to expire after + // the new valid lifetime. + prefixState.invalidationTimer.Reset(vl) } - // We do not have a timer so just create a new one. - prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate) ndp.onLinkPrefixes[prefix] = prefixState } @@ -897,7 +834,7 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // handleAutonomousPrefixInformation assumes that the prefix this pi is for is // not the link-local prefix and the autonomous flag is set. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { vl := pi.ValidLifetime() pl := pi.PreferredLifetime() @@ -912,103 +849,30 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform prefix := pi.Subnet() // Check if we already have an auto-generated address for prefix. - for _, ref := range ndp.nic.endpoints { - if ref.protocol != header.IPv6ProtocolNumber { - continue - } - - if ref.configType != slaac { - continue - } - - addr := ref.ep.ID().LocalAddress - refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: ref.ep.PrefixLen()} + for addr, addrState := range ndp.autoGenAddresses { + refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: addrState.ref.ep.PrefixLen()} if refAddrWithPrefix.Subnet() != prefix { continue } - // - // At this point, we know we are refreshing a SLAAC generated - // IPv6 address with the prefix, prefix. Do the work as outlined - // by RFC 4862 section 5.5.3.e. - // - - addrState, ok := ndp.autoGenAddresses[addr] - if !ok { - panic(fmt.Sprintf("must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr)) - } - - // TODO(b/143713887): Handle deprecating auto-generated address - // after the preferred lifetime. - - // As per RFC 4862 section 5.5.3.e, the valid lifetime of the - // address generated by SLAAC is as follows: - // - // 1) If the received Valid Lifetime is greater than 2 hours or - // greater than RemainingLifetime, set the valid lifetime of - // the address to the advertised Valid Lifetime. - // - // 2) If RemainingLifetime is less than or equal to 2 hours, - // ignore the advertised Valid Lifetime. - // - // 3) Otherwise, reset the valid lifetime of the address to 2 - // hours. - - // Handle the infinite valid lifetime separately as we do not - // keep a timer in this case. - if vl >= header.NDPInfiniteLifetime { - if addrState.invalidationTimer != nil { - // Valid lifetime was finite before, but now it - // is valid forever. - if !addrState.invalidationTimer.Stop() { - *addrState.doNotInvalidate = true - } - addrState.invalidationTimer = nil - addrState.validUntil = time.Time{} - ndp.autoGenAddresses[addr] = addrState - } - - return - } - - var effectiveVl time.Duration - var rl time.Duration - - // If the address was originally set to be valid forever, - // assume the remaining time to be the maximum possible value. - if addrState.invalidationTimer == nil { - rl = header.NDPInfiniteLifetime - } else { - rl = time.Until(addrState.validUntil) - } - - if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { - effectiveVl = vl - } else if rl <= MinPrefixInformationValidLifetimeForUpdate { - ndp.autoGenAddresses[addr] = addrState - return - } else { - effectiveVl = MinPrefixInformationValidLifetimeForUpdate - } - - if addrState.invalidationTimer == nil { - addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate) - } else { - if !addrState.invalidationTimer.Stop() { - *addrState.doNotInvalidate = true - } - addrState.invalidationTimer.Reset(effectiveVl) - } - - addrState.validUntil = time.Now().Add(effectiveVl) - ndp.autoGenAddresses[addr] = addrState + // At this point, we know we are refreshing a SLAAC generated IPv6 address + // with the prefix prefix. Do the work as outlined by RFC 4862 section + // 5.5.3.e. + ndp.refreshAutoGenAddressLifetimes(addr, pl, vl) return } // We do not already have an address within the prefix, prefix. Do the // work as outlined by RFC 4862 section 5.5.3.d if n is configured // to auto-generated global addresses by SLAAC. + ndp.newAutoGenAddress(prefix, pl, vl) +} +// newAutoGenAddress generates a new SLAAC address with the provided lifetimes +// for prefix. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration) { // Are we configured to auto-generate new global addresses? if !ndp.configs.AutoGenGlobalAddresses { return @@ -1028,22 +892,24 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform return } - // Only attempt to generate an interface-specific IID if we have a valid - // link address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address - // (provided by LinkEndpoint.LinkAddress) before reaching this - // point. - linkAddr := ndp.nic.linkEP.LinkAddress() - if !header.IsValidUnicastEthernetAddress(linkAddr) { - return - } + addrBytes := []byte(prefix.ID()) + if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), 0 /* dadCounter */, oIID.SecretKey) + } else { + // Only attempt to generate an interface-specific IID if we have a valid + // link address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + linkAddr := ndp.nic.linkEP.LinkAddress() + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return + } - // Generate an address within prefix from the modified EUI-64 of ndp's - // NIC's Ethernet MAC address. - addrBytes := make([]byte, header.IPv6AddressSize) - copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address]) - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + // Generate an address within prefix from the modified EUI-64 of ndp's NIC's + // Ethernet MAC address. + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + } addr := tcpip.Address(addrBytes) addrWithPrefix := tcpip.AddressWithPrefix{ Address: addr, @@ -1065,29 +931,132 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform return } - if _, err := ndp.nic.addAddressLocked(tcpip.ProtocolAddress{ + protocolAddr := tcpip.ProtocolAddress{ Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addrWithPrefix, - }, FirstPrimaryEndpoint, permanent, slaac); err != nil { - panic(err) + } + // If the preferred lifetime is zero, then the address should be considered + // deprecated. + deprecated := pl == 0 + ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) + if err != nil { + log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err) + } + + state := autoGenAddressState{ + ref: ref, + deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + addrState, ok := ndp.autoGenAddresses[addr] + if !ok { + log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr) + } + addrState.ref.deprecated = true + ndp.notifyAutoGenAddressDeprecated(addr) + }), + invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + ndp.invalidateAutoGenAddress(addr) + }), } - // Setup the timers to deprecate and invalidate this newly generated + // Setup the initial timers to deprecate and invalidate this newly generated // address. - // TODO(b/143713887): Handle deprecating auto-generated addresses - // after the preferred lifetime. + if !deprecated && pl < header.NDPInfiniteLifetime { + state.deprecationTimer.Reset(pl) + } - var doNotInvalidate bool - var vTimer *time.Timer if vl < header.NDPInfiniteLifetime { - vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate) + state.invalidationTimer.Reset(vl) + state.validUntil = time.Now().Add(vl) } - ndp.autoGenAddresses[addr] = autoGenAddressState{ - invalidationTimer: vTimer, - doNotInvalidate: &doNotInvalidate, - validUntil: time.Now().Add(vl), + ndp.autoGenAddresses[addr] = state +} + +// refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated +// address addr. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl time.Duration) { + addrState, ok := ndp.autoGenAddresses[addr] + if !ok { + log.Fatalf("ndp: SLAAC state not found to refresh lifetimes for %s", addr) + } + defer func() { ndp.autoGenAddresses[addr] = addrState }() + + // If the preferred lifetime is zero, then the address should be considered + // deprecated. + deprecated := pl == 0 + wasDeprecated := addrState.ref.deprecated + addrState.ref.deprecated = deprecated + + // Only send the deprecation event if the deprecated status for addr just + // changed from non-deprecated to deprecated. + if !wasDeprecated && deprecated { + ndp.notifyAutoGenAddressDeprecated(addr) + } + + // If addr was preferred for some finite lifetime before, stop the deprecation + // timer so it can be reset. + addrState.deprecationTimer.StopLocked() + + // Reset the deprecation timer if addr has a finite preferred lifetime. + if !deprecated && pl < header.NDPInfiniteLifetime { + addrState.deprecationTimer.Reset(pl) + } + + // As per RFC 4862 section 5.5.3.e, the valid lifetime of the address + // + // + // 1) If the received Valid Lifetime is greater than 2 hours or greater than + // RemainingLifetime, set the valid lifetime of the address to the + // advertised Valid Lifetime. + // + // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the + // advertised Valid Lifetime. + // + // 3) Otherwise, reset the valid lifetime of the address to 2 hours. + + // Handle the infinite valid lifetime separately as we do not keep a timer in + // this case. + if vl >= header.NDPInfiniteLifetime { + addrState.invalidationTimer.StopLocked() + addrState.validUntil = time.Time{} + return + } + + var effectiveVl time.Duration + var rl time.Duration + + // If the address was originally set to be valid forever, assume the remaining + // time to be the maximum possible value. + if addrState.validUntil == (time.Time{}) { + rl = header.NDPInfiniteLifetime + } else { + rl = time.Until(addrState.validUntil) + } + + if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { + effectiveVl = vl + } else if rl <= MinPrefixInformationValidLifetimeForUpdate { + return + } else { + effectiveVl = MinPrefixInformationValidLifetimeForUpdate + } + + addrState.invalidationTimer.StopLocked() + addrState.invalidationTimer.Reset(effectiveVl) + addrState.validUntil = time.Now().Add(effectiveVl) +} + +// notifyAutoGenAddressDeprecated notifies the stack's NDP dispatcher that addr +// has been deprecated. +func (ndp *ndpState) notifyAutoGenAddressDeprecated(addr tcpip.Address) { + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + }) } } @@ -1111,19 +1080,12 @@ func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) { // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool { state, ok := ndp.autoGenAddresses[addr] - if !ok { return false } - if state.invalidationTimer != nil { - state.invalidationTimer.Stop() - state.invalidationTimer = nil - *state.doNotInvalidate = true - } - - state.doNotInvalidate = nil - + state.deprecationTimer.StopLocked() + state.invalidationTimer.StopLocked() delete(ndp.autoGenAddresses, addr) if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { @@ -1136,26 +1098,6 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo return true } -// autoGenAddrInvalidationTimer returns a new invalidation timer for an -// auto-generated address that fires after vl. -// -// doNotInvalidate is used to inform the timer when it fires at the same time -// that an auto-generated address's valid lifetime gets refreshed. See -// autoGenAddrState.doNotInvalidate for more details. -func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer { - return time.AfterFunc(vl, func() { - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if *doNotInvalidate { - *doNotInvalidate = false - return - } - - ndp.invalidateAutoGenAddress(addr) - }) -} - // cleanupHostOnlyState cleans up any state that is only useful for hosts. // // cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 64a9a2b20..108762b6e 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package stack_test +package ndp_test import ( "encoding/binary" @@ -21,6 +21,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -29,6 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) const ( @@ -45,6 +48,10 @@ var ( llAddr1 = header.LinkLocalAddr(linkAddr1) llAddr2 = header.LinkLocalAddr(linkAddr2) llAddr3 = header.LinkLocalAddr(linkAddr3) + dstAddr = tcpip.FullAddress{ + Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + Port: 25, + } ) func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix { @@ -135,6 +142,7 @@ type ndpAutoGenAddrEventType int const ( newAddr ndpAutoGenAddrEventType = iota + deprecatedAddr invalidatedAddr ) @@ -154,18 +162,24 @@ type ndpRDNSSEvent struct { rdnss ndpRDNSS } +type ndpDHCPv6Event struct { + nicID tcpip.NICID + configuration stack.DHCPv6ConfigurationFromNDPRA +} + var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP // related events happen for test purposes. type ndpDispatcher struct { - dadC chan ndpDADEvent - routerC chan ndpRouterEvent - rememberRouter bool - prefixC chan ndpPrefixEvent - rememberPrefix bool - autoGenAddrC chan ndpAutoGenAddrEvent - rdnssC chan ndpRDNSSEvent + dadC chan ndpDADEvent + routerC chan ndpRouterEvent + rememberRouter bool + prefixC chan ndpPrefixEvent + rememberPrefix bool + autoGenAddrC chan ndpAutoGenAddrEvent + rdnssC chan ndpRDNSSEvent + dhcpv6ConfigurationC chan ndpDHCPv6Event } // Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. @@ -239,6 +253,16 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi return true } +func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + deprecatedAddr, + } + } +} + func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { if c := n.autoGenAddrC; c != nil { c <- ndpAutoGenAddrEvent{ @@ -262,11 +286,23 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc } } +// Implements stack.NDPDispatcher.OnDHCPv6Configuration. +func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) { + if c := n.dhcpv6ConfigurationC; c != nil { + c <- ndpDHCPv6Event{ + nicID, + configuration, + } + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid // RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. func TestDADResolve(t *testing.T) { + t.Parallel() + tests := []struct { name string dupAddrDetectTransmits uint8 @@ -399,6 +435,8 @@ func TestDADResolve(t *testing.T) { // a node doing DAD for the same address), or if another node is detected to own // the address already (receive an NA message for the tentative address). func TestDADFail(t *testing.T) { + t.Parallel() + tests := []struct { name string makeBuf func(tgt tcpip.Address) buffer.Prependable @@ -542,6 +580,8 @@ func TestDADFail(t *testing.T) { // TestDADStop tests to make sure that the DAD process stops when an address is // removed. func TestDADStop(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } @@ -614,6 +654,71 @@ func TestDADStop(t *testing.T) { } } +// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 +// link-local addresses will only be assigned after the DAD process resolves. +func TestNICAutoGenAddrDoesDAD(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.DefaultNDPConfigurations() + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + } + + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Address should not be considered bound to the + // NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + linkLocalAddr := header.LinkLocalAddr(linkAddr1) + + // Wait for DAD to resolve. + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // We should get a resolution event after 1s (default time to + // resolve as per default NDP configurations). Waiting for that + // resolution time + an extra 1s without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != linkLocalAddr { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } +} + // TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if // we attempt to update NDP configurations using an invalid NICID. func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { @@ -631,6 +736,8 @@ func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { // configurations without affecting the default NDP configurations or other // interfaces' configurations. func TestSetNDPConfigurations(t *testing.T) { + t.Parallel() + tests := []struct { name string dupAddrDetectTransmits uint8 @@ -779,21 +886,32 @@ func TestSetNDPConfigurations(t *testing.T) { } } -// raBufWithOpts returns a valid NDP Router Advertisement with options. -// -// Note, raBufWithOpts does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { +// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options +// and DHCPv6 configurations specified. +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(0) - ra := header.NDPRouterAdvert(pkt.NDPPayload()) + raPayload := pkt.NDPPayload() + ra := header.NDPRouterAdvert(raPayload) + // Populate the Router Lifetime. + binary.BigEndian.PutUint16(raPayload[2:], rl) + // Populate the Managed Address flag field. + if managedAddress { + // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) + // of the RA payload. + raPayload[1] |= (1 << 7) + } + // Populate the Other Configurations flag field. + if otherConfigurations { + // The Other Configurations flag field is the 6th bit of byte #1 + // (0-indexing) of the RA payload. + raPayload[1] |= (1 << 6) + } opts := ra.Options() opts.Serialize(optSer) - // Populate the Router Lifetime. - binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl) pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -808,6 +926,23 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} } +// raBufWithOpts returns a valid NDP Router Advertisement with options. +// +// Note, raBufWithOpts does not populate any of the RA fields other than the +// Router Lifetime. +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { + return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) +} + +// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related +// fields set. +// +// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the +// DHCPv6 related ones. +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) tcpip.PacketBuffer { + return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) +} + // raBuf returns a valid NDP Router Advertisement. // // Note, raBuf does not populate any of the RA fields other than the @@ -857,6 +992,8 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on // TestNoRouterDiscovery tests that router discovery will not be performed if // configured not to. func TestNoRouterDiscovery(t *testing.T) { + t.Parallel() + // Being configured to discover routers means handle and // discover are set to true and forwarding is set to false. // This tests all possible combinations of the configurations, @@ -869,8 +1006,6 @@ func TestNoRouterDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), } @@ -1011,13 +1146,13 @@ func TestRouterDiscovery(t *testing.T) { expectRouterEvent(llAddr2, true) // Rx an RA from another router (lladdr3) with non-zero lifetime. - l3Lifetime := time.Duration(6) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime))) + const l3LifetimeSeconds = 6 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) expectRouterEvent(llAddr3, true) // Rx an RA from lladdr2 with lesser lifetime. - l2Lifetime := time.Duration(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime))) + const l2LifetimeSeconds = 2 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) select { case <-ndpDisp.routerC: t.Fatal("Should not receive a router event when updating lifetimes for known routers") @@ -1031,7 +1166,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2Lifetime*time.Second+defaultTimeout) + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultTimeout) // Rx an RA from lladdr2 with huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) @@ -1048,7 +1183,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3Lifetime*time.Second+defaultTimeout) + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultTimeout) } // TestRouterDiscoveryMaxRouters tests that only @@ -1105,6 +1240,8 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { // TestNoPrefixDiscovery tests that prefix discovery will not be performed if // configured not to. func TestNoPrefixDiscovery(t *testing.T) { + t.Parallel() + prefix := tcpip.AddressWithPrefix{ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), PrefixLen: 64, @@ -1122,8 +1259,6 @@ func TestNoPrefixDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, 1), } @@ -1480,6 +1615,8 @@ func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { // TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. func TestNoAutoGenAddr(t *testing.T) { + t.Parallel() + prefix, _, _ := prefixSubnetAddr(0, "") // Being configured to auto-generate addresses means handle and @@ -1494,8 +1631,6 @@ func TestNoAutoGenAddr(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } @@ -1637,12 +1772,529 @@ func TestAutoGenAddr(t *testing.T) { } } +// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher, +// channel.Endpoint and stack.Stack. +// +// stack.Stack will have a default route through the router (llAddr3) installed +// and a static link-address (linkAddr3) added to the link address cache for the +// router. +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { + t.Helper() + ndpDisp := &ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: ndpDisp, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: llAddr3, + NIC: nicID, + }}) + s.AddLinkAddress(nicID, llAddr3, linkAddr3) + return ndpDisp, e, s +} + +// addrForNewConnection returns the local address used when creating a new +// connection. +func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) + } + if err := ep.Connect(dstAddr); err != nil { + t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) + } + got, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("ep.GetLocalAddress(): %s", err) + } + return got.Addr +} + +// addrForNewConnectionWithAddr returns the local address used when creating a +// new connection with a specific local address. +func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) + } + if err := ep.Bind(addr); err != nil { + t.Fatalf("ep.Bind(%+v): %s", addr, err) + } + if err := ep.Connect(dstAddr); err != nil { + t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) + } + got, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("ep.GetLocalAddress(): %s", err) + } + return got.Addr +} + +// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when +// receiving a PI with 0 preferred lifetime. +func TestAutoGenAddrDeprecateFromPI(t *testing.T) { + const nicID = 1 + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() + + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } + + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } + + // Receive PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + expectPrimaryAddr(addr1) + + // Deprecate addr for prefix1 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + // addr should still be the primary endpoint as there are no other addresses. + expectPrimaryAddr(addr1) + + // Refresh lifetimes of addr generated from prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Deprecate addr for prefix2 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr1 should be the primary endpoint now since addr2 is deprecated but + // addr1 is not. + expectPrimaryAddr(addr1) + // addr2 is deprecated but if explicitly requested, it should be used. + fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + } + + // Another PI w/ 0 preferred lifetime should not result in a deprecation + // event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + } + + // Refresh lifetimes of addr generated from prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) +} + +// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated +// when its preferred lifetime expires. +func TestAutoGenAddrTimerDeprecation(t *testing.T) { + const nicID = 1 + const newMinVL = 2 + newMinVLDuration := newMinVL * time.Second + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() + + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } + + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } + + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Receive a PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr1) + + // Refresh lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since addr1 is deprecated but + // addr2 is not. + expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. + fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + } + + // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make + // sure we do not get a deprecation event again. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + } + + // Refresh lifetimes for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + // addr1 is the primary endpoint again since it is non-deprecated now. + expectPrimaryAddr(addr1) + + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since it is not deprecated. + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + } + + // Wait for addr of prefix1 to be invalidated. + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultTimeout) + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Refresh both lifetimes for addr of prefix2 to the same value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + + // Wait for a deprecation then invalidation events, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation handlers could be handled in + // either deprecation then invalidation, or invalidation then deprecation + // (which should be cancelled by the invalidation handler). + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we should not get a deprecation + // event after. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + case <-time.After(defaultTimeout): + } + } else { + t.Fatalf("got unexpected auto-generated event") + } + + case <-time.After(newMinVLDuration + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should not have %s in the list of addresses", addr2) + } + // Should not have any primary endpoints. + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); got != want { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) + } + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) + } + + if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { + t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute) + } +} + +// Tests transitioning a SLAAC address's valid lifetime between finite and +// infinite values. +func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { + const infiniteVLSeconds = 2 + const minVLSeconds = 1 + savedIL := header.NDPInfiniteLifetime + savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL + header.NDPInfiniteLifetime = savedIL + }() + stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second + header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + tests := []struct { + name string + infiniteVL uint32 + }{ + { + name: "EqualToInfiniteVL", + infiniteVL: infiniteVLSeconds, + }, + // Our implementation supports changing header.NDPInfiniteLifetime for tests + // such that a packet can be received where the lifetime field has a value + // greater than header.NDPInfiniteLifetime. Because of this, we test to make + // sure that receiving a value greater than header.NDPInfiniteLifetime is + // handled the same as when receiving a value equal to + // header.NDPInfiniteLifetime. + { + name: "MoreThanInfiniteVL", + infiniteVL: infiniteVLSeconds + 1, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with finite prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + default: + t.Fatal("expected addr auto gen event") + } + + // Receive an new RA with prefix with infinite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) + + // Receive a new RA with prefix with finite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + case <-time.After(minVLSeconds*time.Second + defaultTimeout): + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } + }) +} + // TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an // auto-generated address only gets updated when required to, as specified in // RFC 4862 section 5.5.3.e. func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const infiniteVL = 4294967295 - const newMinVL = 5 + const newMinVL = 4 saved := stack.MinPrefixInformationValidLifetimeForUpdate defer func() { stack.MinPrefixInformationValidLifetimeForUpdate = saved @@ -1911,6 +2563,110 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { } } +// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use +// opaque interface identifiers when configured to do so. +func TestAutoGenAddrWithOpaqueIID(t *testing.T) { + t.Parallel() + + const nicID = 1 + const nicName = "nic1" + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte + secretKey := secretKeyBuf[:] + n, err := rand.Read(secretKey) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) + } + + prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) + prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) + // addr1 and addr2 are the addresses that are expected to be generated when + // stack.Stack is configured to generate opaque interface identifiers as + // defined by RFC 7217. + addrBytes := []byte(subnet1.ID()) + addr1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)), + PrefixLen: 64, + } + addrBytes = []byte(subnet2.ID()) + addr2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)), + PrefixLen: 64, + } + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive an RA with prefix1 in a PI. + const validLifetimeSecondPrefix1 = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + + // Receive an RA with prefix2 in a PI with a large valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + + // Wait for addr of prefix1 to be invalidated. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } +} + // TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event // to the integrator when an RA is received with the NDP Recursive DNS Server // option with at least one valid address. @@ -2312,3 +3068,94 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { default: } } + +// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly +// informed when new information about what configurations are available via +// DHCPv6 is learned. +func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) { + t.Helper() + select { + case e := <-ndpDisp.dhcpv6ConfigurationC: + if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" { + t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DHCPv6 configuration event") + } + } + + expectNoDHCPv6Event := func() { + t.Helper() + select { + case <-ndpDisp.dhcpv6ConfigurationC: + t.Fatal("unexpected DHCPv6 configuration event") + default: + } + } + + // The initial DHCPv6 configuration should be stack.DHCPv6NoConfiguration. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + // Receiving the same update again should not result in an event to the + // NDPDispatcher. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Managed Address. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectDHCPv6Event(stack.DHCPv6ManagedAddress) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to none. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectDHCPv6Event(stack.DHCPv6NoConfiguration) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Managed Address. + // + // Note, when the M flag is set, the O flag is redundant. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectDHCPv6Event(stack.DHCPv6ManagedAddress) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectNoDHCPv6Event() + // Even though the DHCPv6 flags are different, the effective configuration is + // the same so we should not receive a new event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectNoDHCPv6Event() + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ddd014658..3810c6602 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -27,11 +27,11 @@ import ( // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { - stack *Stack - id tcpip.NICID - name string - linkEP LinkEndpoint - loopback bool + stack *Stack + id tcpip.NICID + name string + linkEP LinkEndpoint + context NICContext mu sync.RWMutex spoofing bool @@ -85,7 +85,7 @@ const ( ) // newNIC returns a new NIC using the default NDP configurations from stack. -func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { +func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For // example, make sure that the link address it provides is a valid // unicast ethernet address. @@ -99,7 +99,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback id: id, name: name, linkEP: ep, - loopback: loopback, + context: ctx, primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), @@ -174,23 +174,28 @@ func (n *NIC) enable() *tcpip.Error { return err } - if !n.stack.autoGenIPv6LinkLocal { + // Do not auto-generate an IPv6 link-local address for loopback devices. + if !n.stack.autoGenIPv6LinkLocal || n.isLoopback() { return nil } - l2addr := n.linkEP.LinkAddress() + var addr tcpip.Address + if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey) + } else { + l2addr := n.linkEP.LinkAddress() - // Only attempt to generate the link-local address if we have a - // valid MAC address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address - // (provided by LinkEndpoint.LinkAddress) before reaching this - // point. - if !header.IsValidUnicastEthernetAddress(l2addr) { - return nil - } + // Only attempt to generate the link-local address if we have a valid MAC + // address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + if !header.IsValidUnicastEthernetAddress(l2addr) { + return nil + } - addr := header.LinkLocalAddr(l2addr) + addr = header.LinkLocalAddr(l2addr) + } _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ Protocol: header.IPv6ProtocolNumber, @@ -235,6 +240,10 @@ func (n *NIC) isPromiscuousMode() bool { return rv } +func (n *NIC) isLoopback() bool { + return n.linkEP.Capabilities()&CapabilityLoopback != 0 +} + // setSpoofing enables or disables address spoofing. func (n *NIC) setSpoofing(enable bool) { n.mu.Lock() @@ -244,17 +253,47 @@ func (n *NIC) setSpoofing(enable bool) { // primaryEndpoint returns the primary endpoint of n for the given network // protocol. +// +// primaryEndpoint will return the first non-deprecated endpoint if such an +// endpoint exists. If no non-deprecated endpoint exists, the first deprecated +// endpoint will be returned. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { n.mu.RLock() defer n.mu.RUnlock() + var deprecatedEndpoint *referencedNetworkEndpoint for _, r := range n.primary[protocol] { - if r.isValidForOutgoing() && r.tryIncRef() { - return r + if !r.isValidForOutgoing() { + continue + } + + if !r.deprecated { + if r.tryIncRef() { + // r is not deprecated, so return it immediately. + // + // If we kept track of a deprecated endpoint, decrement its reference + // count since it was incremented when we decided to keep track of it. + if deprecatedEndpoint != nil { + deprecatedEndpoint.decRefLocked() + deprecatedEndpoint = nil + } + + return r + } + } else if deprecatedEndpoint == nil && r.tryIncRef() { + // We prefer an endpoint that is not deprecated, but we keep track of r in + // case n doesn't have any non-deprecated endpoints. + // + // If we end up finding a more preferred endpoint, r's reference count + // will be decremented when such an endpoint is found. + deprecatedEndpoint = r } } - return nil + // n doesn't have any valid non-deprecated endpoints, so return + // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated + // endpoints either). + return deprecatedEndpoint } // hasPermanentAddrLocked returns true if n has a permanent (including currently @@ -362,7 +401,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, temporary, static) + }, peb, temporary, static, false) n.mu.Unlock() return ref @@ -411,10 +450,10 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } } - return n.addAddressLocked(protocolAddress, peb, permanent, static) + return n.addAddressLocked(protocolAddress, peb, permanent, static, false) } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { // TODO(b/141022673): Validate IP address before adding them. // Sanity check. @@ -450,6 +489,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar protocol: protocolAddress.Protocol, kind: kind, configType: configType, + deprecated: deprecated, } // Set up cache if link address resolution exists for this protocol. @@ -548,6 +588,51 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { return addrs } +// primaryAddress returns the primary address associated with this NIC. +// +// primaryAddress will return the first non-deprecated address if such an +// address exists. If no non-deprecated address exists, the first deprecated +// address will be returned. +func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { + n.mu.RLock() + defer n.mu.RUnlock() + + list, ok := n.primary[proto] + if !ok { + return tcpip.AddressWithPrefix{} + } + + var deprecatedEndpoint *referencedNetworkEndpoint + for _, ref := range list { + // Don't include tentative, expired or tempory endpoints to avoid confusion + // and prevent the caller from using those. + switch ref.getKind() { + case permanentTentative, permanentExpired, temporary: + continue + } + + if !ref.deprecated { + return tcpip.AddressWithPrefix{ + Address: ref.ep.ID().LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + } + } + + if deprecatedEndpoint == nil { + deprecatedEndpoint = ref + } + } + + if deprecatedEndpoint != nil { + return tcpip.AddressWithPrefix{ + Address: deprecatedEndpoint.ep.ID().LocalAddress, + PrefixLen: deprecatedEndpoint.ep.PrefixLen(), + } + } + + return tcpip.AddressWithPrefix{} +} + // AddAddressRange adds a range of addresses to n, so that it starts accepting // packets targeted at the given addresses and network protocol. The range is // given by a subnet address, and all addresses contained in the subnet are @@ -1104,6 +1189,11 @@ type referencedNetworkEndpoint struct { // configType is the method that was used to configure this endpoint. // This must never change after the endpoint is added to a NIC. configType networkEndpointConfigType + + // deprecated indicates whether or not the endpoint should be considered + // deprecated. That is, when deprecated is true, other endpoints that are not + // deprecated should be preferred. + deprecated bool } func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 61fd46d66..2b8751d49 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -234,15 +234,15 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. - WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 34307ae07..517f4b941 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -158,7 +158,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack return tcpip.ErrInvalidEndpointState } - err := r.ref.ep.WritePacket(r, gso, params, r.Loop, pkt) + err := r.ref.ep.WritePacket(r, gso, params, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { @@ -174,7 +174,7 @@ func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params Network return 0, tcpip.ErrInvalidEndpointState } - n, err := r.ref.ep.WritePackets(r, gso, pkts, params, r.Loop) + n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) } @@ -195,7 +195,7 @@ func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - if err := r.ref.ep.WriteHeaderIncludedPacket(r, r.Loop, pkt); err != nil { + if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 7a9600679..41bf9fd9b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -352,6 +352,38 @@ func (u *uniqueIDGenerator) UniqueID() uint64 { return atomic.AddUint64((*uint64)(u), 1) } +// NICNameFromID is a function that returns a stable name for the specified NIC, +// even if different NIC IDs are used to refer to the same NIC in different +// program runs. It is used when generating opaque interface identifiers (IIDs). +// If the NIC was created with a name, it will be passed to NICNameFromID. +// +// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are +// generated for the same prefix on differnt NICs. +type NICNameFromID func(tcpip.NICID, string) string + +// OpaqueInterfaceIdentifierOptions holds the options related to the generation +// of opaque interface indentifiers (IIDs) as defined by RFC 7217. +type OpaqueInterfaceIdentifierOptions struct { + // NICNameFromID is a function that returns a stable name for a specified NIC, + // even if the NIC ID changes over time. + // + // Must be specified to generate the opaque IID. + NICNameFromID NICNameFromID + + // SecretKey is a pseudo-random number used as the secret key when generating + // opaque IIDs as defined by RFC 7217. The key SHOULD be at least + // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness + // requirements for security as outlined by RFC 4086. SecretKey MUST NOT + // change between program runs, unless explicitly changed. + // + // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey + // MUST NOT be modified after Stack is created. + // + // May be nil, but a nil value is highly discouraged to maintain + // some level of randomness between nodes. + SecretKey []byte +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -412,8 +444,8 @@ type Stack struct { ndpConfigs NDPConfigurations // autoGenIPv6LinkLocal determines whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled NICs. - // See the AutoGenIPv6LinkLocal field of Options for more details. + // to auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. autoGenIPv6LinkLocal bool // ndpDisp is the NDP event dispatcher that is used to send the netstack @@ -422,6 +454,10 @@ type Stack struct { // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID + + // opaqueIIDOpts hold the options for generating opaque interface identifiers + // (IIDs) as outlined by RFC 7217. + opaqueIIDOpts OpaqueInterfaceIdentifierOptions } // UniqueID is an abstract generator of unique identifiers. @@ -460,13 +496,15 @@ type Options struct { // before assigning an address to a NIC. NDPConfigs NDPConfigurations - // AutoGenIPv6LinkLocal determins whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled NICs. + // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to + // auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. + // // Note, setting this to true does not mean that a link-local address - // will be assigned right away, or at all. If Duplicate Address - // Detection is enabled, an address will only be assigned if it - // successfully resolves. If it fails, no further attempt will be made - // to auto-generate an IPv6 link-local address. + // will be assigned right away, or at all. If Duplicate Address Detection + // is enabled, an address will only be assigned if it successfully resolves. + // If it fails, no further attempt will be made to auto-generate an IPv6 + // link-local address. // // The generated link-local address will follow RFC 4291 Appendix A // guidelines. @@ -479,6 +517,10 @@ type Options struct { // RawFactory produces raw endpoints. Raw endpoints are enabled only if // this is non-nil. RawFactory RawFactory + + // OpaqueIIDOpts hold the options for generating opaque interface identifiers + // (IIDs) as outlined by RFC 7217. + OpaqueIIDOpts OpaqueInterfaceIdentifierOptions } // TransportEndpointInfo holds useful information about a transport endpoint @@ -549,6 +591,7 @@ func New(opts Options) *Stack { autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, + opaqueIIDOpts: opts.OpaqueIIDOpts, } // Add specified network protocols. @@ -753,9 +796,30 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) } -// createNIC creates a NIC with the provided id and link-layer endpoint, and -// optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { +// NICContext is an opaque pointer used to store client-supplied NIC metadata. +type NICContext interface{} + +// NICOptions specifies the configuration of a NIC as it is being created. +// The zero value creates an enabled, unnamed NIC. +type NICOptions struct { + // Name specifies the name of the NIC. + Name string + + // Disabled specifies whether to avoid calling Attach on the passed + // LinkEndpoint. + Disabled bool + + // Context specifies user-defined data that will be returned in stack.NICInfo + // for the NIC. Clients of this library can use it to add metadata that + // should be tracked alongside a NIC, to avoid having to keep a + // map[tcpip.NICID]metadata mirroring stack.Stack's nic map. + Context NICContext +} + +// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and +// NICOptions. See the documentation on type NICOptions for details on how +// NICs can be configured. +func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -764,44 +828,20 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, return tcpip.ErrDuplicateNICID } - n := newNIC(s, id, name, ep, loopback) + n := newNIC(s, id, opts.Name, ep, opts.Context) s.nics[id] = n - if enabled { + if !opts.Disabled { return n.enable() } return nil } -// CreateNIC creates a NIC with the provided id and link-layer endpoint. +// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls +// `LinkEndpoint.Attach` to start delivering packets to it. func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, "", ep, true, false) -} - -// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, -// and a human-readable name. -func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, true, false) -} - -// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer -// endpoint, and a human-readable name. -func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, true, true) -} - -// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, -// but leave it disable. Stack.EnableNIC must be called before the link-layer -// endpoint starts delivering packets to it. -func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, "", ep, false, false) -} - -// CreateDisabledNamedNIC is a combination of CreateNamedNIC and -// CreateDisabledNIC. -func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, false, false) + return s.CreateNICWithOptions(id, ep, NICOptions{}) } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -855,6 +895,18 @@ type NICInfo struct { MTU uint32 Stats NICStats + + // Context is user-supplied data optionally supplied in CreateNICWithOptions. + // See type NICOptions for more details. + Context NICContext +} + +// HasNIC returns true if the NICID is defined in the stack. +func (s *Stack) HasNIC(id tcpip.NICID) bool { + s.mu.RLock() + _, ok := s.nics[id] + s.mu.RUnlock() + return ok } // NICInfo returns a map of NICIDs to their associated information. @@ -868,7 +920,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Up: true, // Netstack interfaces are always up. Running: nic.linkEP.IsAttached(), Promiscuous: nic.isPromiscuousMode(), - Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0, + Loopback: nic.isLoopback(), } nics[id] = NICInfo{ Name: nic.name, @@ -877,6 +929,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, + Context: nic.context, } } return nics @@ -993,9 +1046,11 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { return nics } -// GetMainNICAddress returns the first primary address and prefix for the given -// NIC and protocol. Returns an error if the NIC doesn't exist and an empty -// value if the NIC doesn't have a primary address for the given protocol. +// GetMainNICAddress returns the first non-deprecated primary address and prefix +// for the given NIC and protocol. If no non-deprecated primary address exists, +// a deprecated primary address and prefix will be returned. Returns an error if +// the NIC doesn't exist and an empty value if the NIC doesn't have a primary +// address for the given protocol. func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -1005,12 +1060,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID } - for _, a := range nic.PrimaryAddresses() { - if a.Protocol == protocol { - return a.AddressWithPrefix, nil - } - } - return tcpip.AddressWithPrefix{}, nil + return nic.primaryAddress(protocol), nil } func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { @@ -1032,7 +1082,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok { if ref := s.getRefEP(nic, localAddr, netProto); ref != nil { - return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback), nil + return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil } } } else { @@ -1048,7 +1098,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n remoteAddr = ref.ep.ID().LocalAddress } - r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback) + r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) if needRoute { r.NextHop = route.Gateway } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 8fc034ca1..e8de4e87d 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -24,13 +24,14 @@ import ( "sort" "strings" "testing" - "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -48,6 +49,8 @@ const ( // where another value is explicitly used. It is chosen to match the MTU // of loopback interfaces on linux systems. defaultMTU = 65536 + + linkAddr = "\x02\x02\x03\x04\x05\x06" ) // fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and @@ -122,7 +125,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -133,7 +136,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params b[1] = f.id.LocalAddress[0] b[2] = byte(params.Protocol) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) @@ -141,7 +144,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } @@ -149,11 +152,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -1894,55 +1897,67 @@ func TestNICForwarding(t *testing.T) { } // TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses -// (or lack there-of if disabled (default)). Note, DAD will be disabled in -// these tests. +// using the modified EUI-64 of the NIC's MAC address (or lack there-of if +// disabled (default)). Note, DAD will be disabled in these tests. func TestNICAutoGenAddr(t *testing.T) { tests := []struct { name string autoGen bool linkAddr tcpip.LinkAddress + iidOpts stack.OpaqueInterfaceIdentifierOptions shouldGen bool }{ { "Disabled", false, - linkAddr1, + linkAddr, + stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(nicID tcpip.NICID, _ string) string { + return fmt.Sprintf("nic%d", nicID) + }, + }, false, }, { "Enabled", true, - linkAddr1, + linkAddr, + stack.OpaqueInterfaceIdentifierOptions{}, true, }, { "Nil MAC", true, tcpip.LinkAddress([]byte(nil)), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Empty MAC", true, tcpip.LinkAddress(""), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Invalid MAC", true, tcpip.LinkAddress("\x01\x02\x03"), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Multicast MAC", true, tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Unspecified MAC", true, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, } @@ -1951,13 +1966,12 @@ func TestNICAutoGenAddr(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + OpaqueIIDOpts: test.iidOpts, } if test.autoGen { - // Only set opts.AutoGenIPv6LinkLocal when - // test.autoGen is true because - // opts.AutoGenIPv6LinkLocal should be false by - // default. + // Only set opts.AutoGenIPv6LinkLocal when test.autoGen is true because + // opts.AutoGenIPv6LinkLocal should be false by default. opts.AutoGenIPv6LinkLocal = true } @@ -1973,8 +1987,8 @@ func TestNICAutoGenAddr(t *testing.T) { } if test.shouldGen { - // Should have auto-generated an address and - // resolved immediately (DAD is disabled). + // Should have auto-generated an address and resolved immediately (DAD + // is disabled). if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) } @@ -1988,66 +2002,215 @@ func TestNICAutoGenAddr(t *testing.T) { } } -// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 -// link-local addresses will only be assigned after the DAD process resolves. -func TestNICAutoGenAddrDoesDAD(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), +// TestNICContextPreservation tests that you can read out via stack.NICInfo the +// Context data you pass via NICContext.Context in stack.CreateNICWithOptions. +func TestNICContextPreservation(t *testing.T) { + var ctx *int + tests := []struct { + name string + opts stack.NICOptions + want stack.NICContext + }{ + { + "context_set", + stack.NICOptions{Context: ctx}, + ctx, + }, + { + "context_not_set", + stack.NICOptions{}, + nil, + }, } - ndpConfigs := stack.DefaultNDPConfigurations() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - AutoGenIPv6LinkLocal: true, - NDPDisp: &ndpDisp, + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + id := tcpip.NICID(1) + ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { + t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) + } + nicinfos := s.NICInfo() + nicinfo, ok := nicinfos[id] + if !ok { + t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) + } + if got, want := nicinfo.Context == test.want, true; got != want { + t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) + } + }) } +} - e := channel.New(10, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } +// TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local +// addresses with opaque interface identifiers. Link Local addresses should +// always be generated with opaque IIDs if configured to use them, even if the +// NIC has an invalid MAC address. +func TestNICAutoGenAddrWithOpaque(t *testing.T) { + const nicID = 1 - // Address should not be considered bound to the - // NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte + n, err := rand.Read(secretKey[:]) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("rand.Read(_): %s", err) } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n) } - linkLocalAddr := header.LinkLocalAddr(linkAddr1) + tests := []struct { + name string + nicName string + autoGen bool + linkAddr tcpip.LinkAddress + secretKey []byte + }{ + { + name: "Disabled", + nicName: "nic1", + autoGen: false, + linkAddr: linkAddr, + secretKey: secretKey[:], + }, + { + name: "Enabled", + nicName: "nic1", + autoGen: true, + linkAddr: linkAddr, + secretKey: secretKey[:], + }, + // These are all cases where we would not have generated a + // link-local address if opaque IIDs were disabled. + { + name: "Nil MAC and empty nicName", + nicName: "", + autoGen: true, + linkAddr: tcpip.LinkAddress([]byte(nil)), + secretKey: secretKey[:1], + }, + { + name: "Empty MAC and empty nicName", + autoGen: true, + linkAddr: tcpip.LinkAddress(""), + secretKey: secretKey[:2], + }, + { + name: "Invalid MAC", + nicName: "test", + autoGen: true, + linkAddr: tcpip.LinkAddress("\x01\x02\x03"), + secretKey: secretKey[:3], + }, + { + name: "Multicast MAC", + nicName: "test2", + autoGen: true, + linkAddr: tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + secretKey: secretKey[:4], + }, + { + name: "Unspecified MAC and nil SecretKey", + nicName: "test3", + autoGen: true, + linkAddr: tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + }, + } - // Wait for DAD to resolve. - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // We should get a resolution event after 1s (default time to - // resolve as per default NDP configurations). Waiting for that - // resolution time + an extra 1s without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != linkLocalAddr { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: test.secretKey, + }, + } + + if test.autoGen { + // Only set opts.AutoGenIPv6LinkLocal when + // test.autoGen is true because + // opts.AutoGenIPv6LinkLocal should be false by + // default. + opts.AutoGenIPv6LinkLocal = true + } + + e := channel.New(10, 1280, test.linkAddr) + s := stack.New(opts) + nicOpts := stack.NICOptions{Name: test.nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) + } + + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + + if test.autoGen { + // Should have auto-generated an address and + // resolved immediately (DAD is disabled). + if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddrWithOpaqueIID(test.nicName, 0, test.secretKey), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } + } else { + // Should not have auto-generated an address. + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + } + }) } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) +} + +// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are +// not auto-generated for loopback NICs. +func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { + const nicID = 1 + const nicName = "nicName" + + tests := []struct { + name string + opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions + }{ + { + name: "IID From MAC", + opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{}, + }, + { + name: "Opaque IID", + opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + }, + }, } - if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, + } + + e := loopback.New() + s := stack.New(opts) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want) + } + }) } } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 3b28b06d0..5e9237de9 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -41,7 +41,7 @@ const ( type testContext struct { t *testing.T - linkEPs map[string]*channel.Endpoint + linkEps map[tcpip.NICID]*channel.Endpoint s *stack.Stack ep tcpip.Endpoint @@ -61,35 +61,29 @@ func (c *testContext) createV6Endpoint(v6only bool) { c.t.Fatalf("NewEndpoint failed: %v", err) } - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.ep.SetSockOpt(v); err != nil { + if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { c.t.Fatalf("SetSockOpt failed: %v", err) } } -// newDualTestContextMultiNic creates the testing context and also linkEpNames -// named NICs. -func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { +// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. +func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - linkEPs := make(map[string]*channel.Endpoint) - for i, linkEpName := range linkEpNames { - channelEP := channel.New(256, mtu, "") - nicID := tcpip.NICID(i + 1) - if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil { + linkEps := make(map[tcpip.NICID]*channel.Endpoint) + for _, linkEpID := range linkEpIDs { + channelEp := channel.New(256, mtu, "") + if err := s.CreateNIC(linkEpID, channelEp); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - linkEPs[linkEpName] = channelEP + linkEps[linkEpID] = channelEp - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { + if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress IPv4 failed: %v", err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { + if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil { t.Fatalf("AddAddress IPv6 failed: %v", err) } } @@ -108,7 +102,7 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) return &testContext{ t: t, s: s, - linkEPs: linkEPs, + linkEps: linkEps, } } @@ -125,7 +119,7 @@ func newPayload() []byte { return b } -func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -156,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -186,7 +180,7 @@ func TestTransportDemuxerRegister(t *testing.T) { func TestDistribution(t *testing.T) { type endpointSockopts struct { reuse int - bindToDevice string + bindToDevice tcpip.NICID } for _, test := range []struct { name string @@ -194,71 +188,71 @@ func TestDistribution(t *testing.T) { endpoints []endpointSockopts // wantedDistribution is the wanted ratio of packets received on each // endpoint for each NIC on which packets are injected. - wantedDistributions map[string][]float64 + wantedDistributions map[tcpip.NICID][]float64 }{ { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, + {1, 0}, + {1, 0}, + {1, 0}, + {1, 0}, + {1, 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. - "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, + 1: {0.2, 0.2, 0.2, 0.2, 0.2}, }, }, { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - endpointSockopts{0, "dev0"}, - endpointSockopts{0, "dev1"}, - endpointSockopts{0, "dev2"}, + {0, 1}, + {0, 2}, + {0, 3}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. - "dev0": []float64{1, 0, 0}, + 1: {1, 0, 0}, // Injected packets on dev1 go only to the endpoint bound to dev1. - "dev1": []float64{0, 1, 0}, + 2: {0, 1, 0}, // Injected packets on dev2 go only to the endpoint bound to dev2. - "dev2": []float64{0, 0, 1}, + 3: {0, 0, 1}, }, }, { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - endpointSockopts{1, "dev0"}, - endpointSockopts{1, "dev0"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, ""}, + {1, 1}, + {1, 1}, + {1, 2}, + {1, 2}, + {1, 2}, + {1, 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to // dev0. - "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, + 1: {0.5, 0.5, 0, 0, 0, 0}, // Injected packets on dev1 get distributed among endpoints bound to // dev1 or unbound. - "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, // Injected packets on dev999 go only to the unbound. - "dev999": []float64{0, 0, 0, 0, 0, 1}, + 1000: {0, 0, 0, 0, 0, 1}, }, }, } { t.Run(test.name, func(t *testing.T) { for device, wantedDistribution := range test.wantedDistributions { - t.Run(device, func(t *testing.T) { - var devices []string + t.Run(string(device), func(t *testing.T) { + var devices []tcpip.NICID for d := range test.wantedDistributions { devices = append(devices, d) } - c := newDualTestContextMultiNic(t, defaultMTU, devices) + c := newDualTestContextMultiNIC(t, defaultMTU, devices) defer c.cleanup() c.createV6Endpoint(false) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 748ce4ea5..f50604a8a 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -102,13 +102,23 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// SetSockOptBool sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + // SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f62fd729f..72b5ce179 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -423,17 +423,25 @@ type Endpoint interface { // SetSockOpt sets a socket option. opt should be one of the *Option types. SetSockOpt(opt interface{}) *Error + // SetSockOptBool sets a socket option, for simple cases where a value + // has the bool type. + SetSockOptBool(opt SockOptBool, v bool) *Error + // SetSockOptInt sets a socket option, for simple cases where a value // has the int type. - SetSockOptInt(opt SockOpt, v int) *Error + SetSockOptInt(opt SockOptInt, v int) *Error // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. GetSockOpt(opt interface{}) *Error + // GetSockOptBool gets a socket option for simple cases where a return + // value has the bool type. + GetSockOptBool(SockOptBool) (bool, *Error) + // GetSockOptInt gets a socket option for simple cases where a return // value has the int type. - GetSockOptInt(SockOpt) (int, *Error) + GetSockOptInt(SockOptInt) (int, *Error) // State returns a socket's lifecycle state. The returned value is // protocol-specific and is primarily used for diagnostics. @@ -488,13 +496,22 @@ type WriteOptions struct { Atomic bool } -// SockOpt represents socket options which values have the int type. -type SockOpt int +// SockOptBool represents socket options which values have the bool type. +type SockOptBool int + +const ( + // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 + // socket is to be restricted to sending and receiving IPv6 packets only. + V6OnlyOption SockOptBool = iota +) + +// SockOptInt represents socket options which values have the int type. +type SockOptInt int const ( // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the // number of unread bytes in the input buffer should be returned. - ReceiveQueueSizeOption SockOpt = iota + ReceiveQueueSizeOption SockOptInt = iota // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to // specify the send buffer size option. @@ -521,10 +538,6 @@ const ( // the endpoint should be cleared and returned. type ErrorOption struct{} -// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 -// socket is to be restricted to sending and receiving IPv6 packets only. -type V6OnlyOption int - // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be // held until segments are full by the TCP transport protocol. type CorkOption int @@ -539,7 +552,7 @@ type ReusePortOption int // BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets // should bind only on a specific NIC. -type BindToDeviceOption string +type BindToDeviceOption NICID // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. type QuickAckOption int diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go new file mode 100644 index 000000000..f5f01f32f --- /dev/null +++ b/pkg/tcpip/timer.go @@ -0,0 +1,161 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "sync" + "time" +) + +// cancellableTimerInstance is a specific instance of CancellableTimer. +// +// Different instances are created each time CancellableTimer is Reset so each +// timer has its own earlyReturn signal. This is to address a bug when a +// CancellableTimer is stopped and reset in quick succession resulting in a +// timer instance's earlyReturn signal being affected or seen by another timer +// instance. +// +// Consider the following sceneario where timer instances share a common +// earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a +// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second +// (B), third (C), and fourth (D) instance of the timer firing, respectively): +// T1: Obtain L +// T1: Create a new CancellableTimer w/ lock L (create instance A) +// T2: instance A fires, blocked trying to obtain L. +// T1: Attempt to stop instance A (set earlyReturn = true) +// T1: Reset timer (create instance B) +// T3: instance B fires, blocked trying to obtain L. +// T1: Attempt to stop instance B (set earlyReturn = true) +// T1: Reset timer (create instance C) +// T4: instance C fires, blocked trying to obtain L. +// T1: Attempt to stop instance C (set earlyReturn = true) +// T1: Reset timer (create instance D) +// T5: instance D fires, blocked trying to obtain L. +// T1: Release L +// +// Now that T1 has released L, any of the 4 timer instances can take L and check +// earlyReturn. If the timers simply check earlyReturn and then do nothing +// further, then instance D will never early return even though it was not +// requested to stop. If the timers reset earlyReturn before early returning, +// then all but one of the timers will do work when only one was expected to. +// If CancellableTimer resets earlyReturn when resetting, then all the timers +// will fire (again, when only one was expected to). +// +// To address the above concerns the simplest solution was to give each timer +// its own earlyReturn signal. +type cancellableTimerInstance struct { + timer *time.Timer + + // Used to inform the timer to early return when it gets stopped while the + // lock the timer tries to obtain when fired is held (T1 is a goroutine that + // tries to cancel the timer and T2 is the goroutine that handles the timer + // firing): + // T1: Obtain the lock, then call StopLocked() + // T2: timer fires, and gets blocked on obtaining the lock + // T1: Releases lock + // T2: Obtains lock does unintended work + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using earlyReturn to return early so that once T2 obtains + // the lock, it will see that it is set to true and do nothing further. + earlyReturn *bool +} + +// stop stops the timer instance t from firing if it hasn't fired already. If it +// has fired and is blocked at obtaining the lock, earlyReturn will be set to +// true so that it will early return when it obtains the lock. +func (t *cancellableTimerInstance) stop() { + if t.timer != nil { + t.timer.Stop() + *t.earlyReturn = true + } +} + +// CancellableTimer is a timer that does some work and can be safely cancelled +// when it fires at the same time some "related work" is being done. +// +// The term "related work" is defined as some work that needs to be done while +// holding some lock that the timer must also hold while doing some work. +type CancellableTimer struct { + // The active instance of a cancellable timer. + instance cancellableTimerInstance + + // locker is the lock taken by the timer immediately after it fires and must + // be held when attempting to stop the timer. + // + // Must never change after being assigned. + locker sync.Locker + + // fn is the function that will be called when a timer fires and has not been + // signaled to early return. + // + // fn MUST NOT attempt to lock locker. + // + // Must never change after being assigned. + fn func() +} + +// StopLocked prevents the Timer from firing if it has not fired already. +// +// If the timer is blocked on obtaining the t.locker lock when StopLocked is +// called, it will early return instead of calling t.fn. +// +// Note, t will be modified. +// +// t.locker MUST be locked. +func (t *CancellableTimer) StopLocked() { + t.instance.stop() + + // Nothing to do with the stopped instance anymore. + t.instance = cancellableTimerInstance{} +} + +// Reset changes the timer to expire after duration d. +// +// Note, t will be modified. +// +// Reset should only be called on stopped or expired timers. To be safe, callers +// should always call StopLocked before calling Reset. +func (t *CancellableTimer) Reset(d time.Duration) { + // Create a new instance. + earlyReturn := false + t.instance = cancellableTimerInstance{ + timer: time.AfterFunc(d, func() { + t.locker.Lock() + defer t.locker.Unlock() + + if earlyReturn { + // If we reach this point, it means that the timer fired while another + // goroutine called StopLocked while it had the lock. Simply return + // here and do nothing further. + earlyReturn = false + return + } + + t.fn() + }), + earlyReturn: &earlyReturn, + } +} + +// MakeCancellableTimer returns an unscheduled CancellableTimer with the given +// locker and fn. +// +// fn MUST NOT attempt to lock locker. +// +// Callers must call Reset to schedule the timer to fire. +func MakeCancellableTimer(locker sync.Locker, fn func()) CancellableTimer { + return CancellableTimer{locker: locker, fn: fn} +} diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go new file mode 100644 index 000000000..1f735d735 --- /dev/null +++ b/pkg/tcpip/timer_test.go @@ -0,0 +1,236 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package timer_test + +import ( + "sync" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + shortDuration = 1 * time.Nanosecond + middleDuration = 100 * time.Millisecond + longDuration = 1 * time.Second +) + +func TestCancellableTimerFire(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + timer := tcpip.MakeCancellableTimer(&lock, func() { + ch <- struct{}{} + }) + timer.Reset(shortDuration) + + // Wait for timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestCancellableTimerResetFromLongDuration(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(middleDuration) + + lock.Lock() + timer.StopLocked() + lock.Unlock() + + timer.Reset(shortDuration) + + // Wait for timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestCancellableTimerResetFromShortDuration(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + timer.StopLocked() + lock.Unlock() + + // Wait for timer to fire if it wasn't correctly stopped. + select { + case <-ch: + t.Fatal("timer fired after being stopped") + case <-time.After(middleDuration): + } + + timer.Reset(shortDuration) + + // Wait for timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestCancellableTimerImmediatelyStop(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + for i := 0; i < 1000; i++ { + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + timer.StopLocked() + lock.Unlock() + } + + // Wait for timer to fire if it wasn't correctly stopped. + select { + case <-ch: + t.Fatal("timer fired after being stopped") + case <-time.After(middleDuration): + } +} + +func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + timer.StopLocked() + lock.Unlock() + + for i := 0; i < 10; i++ { + timer.Reset(middleDuration) + + lock.Lock() + // Sleep until the timer fires and gets blocked trying to take the lock. + time.Sleep(middleDuration * 2) + timer.StopLocked() + lock.Unlock() + } + + // Wait for double the duration so timers that weren't correctly stopped can + // fire. + select { + case <-ch: + t.Fatal("timer fired after being stopped") + case <-time.After(middleDuration * 2): + } +} + +func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + for i := 0; i < 10; i++ { + // Sleep until the timer fires and gets blocked trying to take the lock. + time.Sleep(middleDuration) + timer.StopLocked() + timer.Reset(shortDuration) + } + lock.Unlock() + + // Wait for double the duration for the last timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestManyCancellableTimerResetUnderLock(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + for i := 0; i < 10; i++ { + timer.StopLocked() + timer.Reset(shortDuration) + } + lock.Unlock() + + // Wait for double the duration for the last timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 9c40931b5..c7ce74cdd 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -350,13 +350,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// SetSockOptBool sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + return nil +} + // SetSockOptInt sets a socket option. Currently not supported. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 0010b5e5f..07ffa8aba 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -247,17 +247,17 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrNotSupported + return tcpip.ErrUnknownProtocolOption } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { - return 0, tcpip.ErrNotSupported +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -265,6 +265,16 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrNotSupported } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrNotSupported +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + // HandlePacket implements stack.PacketEndpoint.HandlePacket. func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { ep.rcvMu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 5aafe2615..85f7eb76b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -509,13 +509,38 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.KeepaliveEnabledOption: + *o = 0 + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -544,21 +569,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } -} - // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) { e.rcvMu.Lock() diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index dfaa4a559..4f361b226 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -391,9 +391,8 @@ func testV4Accept(t *testing.T, c *context.Context) { // Make sure we get the same error when calling the original ep and the // new one. This validates that v4-mapped endpoints are still able to // query the V6Only flag, whereas pure v4 endpoints are not. - var v tcpip.V6OnlyOption - expected := c.EP.GetSockOpt(&v) - if err := nep.GetSockOpt(&v); err != expected { + _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption) + if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected { t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) } @@ -531,8 +530,7 @@ func TestV6AcceptOnV6(t *testing.T) { // Make sure we can still query the v6 only status of the new endpoint, // that is, that it is in fact a v6 socket. - var v tcpip.V6OnlyOption - if err := nep.GetSockOpt(&v); err != nil { + if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { t.Fatalf("GetSockOpt failed failed: %v", err) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8ff125855..830bc1e3e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -885,8 +885,14 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // reject valid data that might already be in flight as the // acceptable window will shrink. if rcvWnd > e.rcvBufSize { + availBefore := e.receiveBufferAvailableLocked() e.rcvBufSize = rcvWnd - e.notifyProtocolGoroutine(notifyReceiveWindowChanged) + availAfter := e.receiveBufferAvailableLocked() + mask := uint32(notifyReceiveWindowChanged) + if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above { + mask |= notifyNonZeroReceiveWindow + } + e.notifyProtocolGoroutine(mask) } // We only update prevCopied when we grow the buffer because in cases @@ -956,11 +962,11 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { e.rcvBufUsed -= len(v) - // If the window was small before this read and if the read - // freed up enough buffer space, to either fit an aMSS or half - // a receive buffer (whichever smaller), then notify the - // protocol goroutine to send a window update. - if e.windowCrossedACKThreshold(len(v)) == 1 { + // If the window was small before this read and if the read freed up + // enough buffer space, to either fit an aMSS or half a receive buffer + // (whichever smaller), then notify the protocol goroutine to send a + // window update. + if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -1134,17 +1140,20 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro return num, tcpip.ControlMessages{}, nil } -// windowCrossedACKThreshold checks if the receive window to be announced now would be -// under aMSS or under half receive buffer, whichever smaller. This is useful as -// a receive side silly window syndrome prevention mechanism. If window grows -// to reasonable value, we should send ACK to the sender to inform the rx space is now -// large. We also want ensure a series of small read()'s won't trigger a flood of -// spurious tiny ACK's. +// windowCrossedACKThreshold checks if the receive window to be announced now +// would be under aMSS or under half receive buffer, whichever smaller. This is +// useful as a receive side silly window syndrome prevention mechanism. If +// window grows to reasonable value, we should send ACK to the sender to inform +// the rx space is now large. We also want ensure a series of small read()'s +// won't trigger a flood of spurious tiny ACK's. // -// For large receive buffers, the threshold is aMSS - once reader reads more than aMSS -// we'll send ACK. For tiny receive buffers, the threshold is half of receive buffer size. -// This is chosen arbitrairly. -func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) int { +// For large receive buffers, the threshold is aMSS - once reader reads more +// than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of +// receive buffer size. This is chosen arbitrairly. +// crossed will be true if the window size crossed the ACK threshold. +// above will be true if the new window is >= ACK threshold and false +// otherwise. +func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) { newAvail := e.receiveBufferAvailableLocked() oldAvail := newAvail - deltaBefore if oldAvail < 0 { @@ -1158,15 +1167,38 @@ func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) int { switch { case oldAvail < threshold && newAvail >= threshold: - return 1 + return true, true case oldAvail >= threshold && newAvail < threshold: - return -1 + return true, false } - return 0 + return false, false +} + +// SetSockOptBool sets a socket option. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + // We only allow this to be set when we're in the initial state. + if e.state != StateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.v6only = v + } + + return nil } // SetSockOptInt sets a socket option. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max @@ -1207,11 +1239,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { e.rcvAutoParams.disabled = true - // Immediatelly send an ACK to uncork the sender silly - // window syndrome prevetion, when our available space - // grows above aMSS or half receive buffer, whichever - // smaller. - if e.windowCrossedACKThreshold(availAfter-availBefore) == 1 { + // Immediately send an ACK to uncork the sender silly window + // syndrome prevetion, when our available space grows above aMSS + // or half receive buffer, whichever smaller. + if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } e.rcvListMu.Unlock() @@ -1283,19 +1314,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil case tcpip.BindToDeviceOption: - e.mu.Lock() - defer e.mu.Unlock() - if v == "" { - e.bindToDevice = 0 - return nil - } - for nicID, nic := range e.stack.NICInfo() { - if nic.Name == string(v) { - e.bindToDevice = nicID - return nil - } + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice } - return tcpip.ErrUnknownDevice + e.mu.Lock() + e.bindToDevice = id + e.mu.Unlock() + return nil case tcpip.QuickAckOption: if v == 0 { @@ -1316,23 +1342,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.notifyProtocolGoroutine(notifyMSSChanged) return nil - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } - - e.mu.Lock() - defer e.mu.Unlock() - - // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { - return tcpip.ErrInvalidEndpointState - } - - e.v6only = v != 0 - return nil - case tcpip.TTLOption: e.mu.Lock() e.ttl = uint8(v) @@ -1473,8 +1482,27 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { return e.rcvBufUsed, nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + return v, nil + } + + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() @@ -1552,12 +1580,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.BindToDeviceOption: e.mu.RLock() - defer e.mu.RUnlock() - if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { - *o = tcpip.BindToDeviceOption(nic.Name) - return nil - } - *o = "" + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.mu.RUnlock() return nil case *tcpip.QuickAckOption: @@ -1567,22 +1591,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil - case *tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrUnknownProtocolOption - } - - e.mu.Lock() - v := e.v6only - e.mu.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil - case *tcpip.TTLOption: e.mu.Lock() *o = tcpip.TTLOption(e.ttl) @@ -2252,10 +2260,9 @@ func (e *endpoint) readyToRead(s *segment) { if s != nil { s.incRef() e.rcvBufUsed += s.data.Size() - // Increase counter if the receive window falls down - // below MSS or half receive buffer size, whichever - // smaller. - if e.windowCrossedACKThreshold(-s.data.Size()) == -1 { + // Increase counter if the receive window falls down below MSS + // or half receive buffer size, whichever smaller. + if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above { e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() } e.rcvList.PushBack(s) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 4c2e458e3..6edfa8dce 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1083,12 +1083,12 @@ func TestTrafficClassV6(t *testing.T) { func TestConnectBindToDevice(t *testing.T) { for _, test := range []struct { name string - device string + device tcpip.NICID want tcp.EndpointState }{ - {"RightDevice", "nic1", tcp.StateEstablished}, - {"WrongDevice", "nic2", tcp.StateSynSent}, - {"AnyDevice", "", tcp.StateEstablished}, + {"RightDevice", 1, tcp.StateEstablished}, + {"WrongDevice", 2, tcp.StateSynSent}, + {"AnyDevice", 0, tcp.StateEstablished}, } { t.Run(test.name, func(t *testing.T) { c := context.New(t, defaultMTU) @@ -3798,46 +3798,41 @@ func TestBindToDeviceOption(t *testing.T) { } defer ep.Close() - if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { - t.Errorf("CreateNamedNIC failed: %v", err) - } - - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { + if err := s.CreateNIC(321, loopback.New()); err != nil { t.Errorf("CreateNIC failed: %v", err) } - // strPtr is used instead of taking the address of string literals, which is + // nicIDPtr is used instead of taking the address of NICID literals, which is // a compiler error. - strPtr := func(s string) *string { + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { return &s } testActions := []struct { name string - setBindToDevice *string + setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error getBindToDevice tcpip.BindToDeviceOption }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, } for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + bindToDevice := tcpip.BindToDeviceOption(88888) + if err := ep.GetSockOpt(&bindToDevice); err != nil { + t.Errorf("GetSockOpt got %v, want %v", err, nil) } if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) + t.Errorf("bindToDevice got %d, want %d", got, want) } }) } @@ -4031,12 +4026,12 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { switch network { case "ipv4": case "ipv6": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err) } case "dual": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil { + t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err) } default: t.Fatalf("unknown network: '%s'", network) @@ -6614,16 +6609,14 @@ func TestIncreaseWindowOnReceive(t *testing.T) { t.Fatalf("expected small, non-zero window: %d", lastWnd) } - // We now have < 1 MSS in the buffer space. Read the data! An + // We now have < 1 MSS in the buffer space. Read the data! An // ack should be sent in response to that. The window was not // zero, but it grew to larger than MSS. - _, _, err := c.EP.Read(nil) - if err != nil { + if _, _, err := c.EP.Read(nil); err != nil { t.Fatalf("Read failed: %v", err) } - _, _, err = c.EP.Read(nil) - if err != nil { + if _, _, err := c.EP.Read(nil); err != nil { t.Fatalf("Read failed: %v", err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index b0a376eba..822907998 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -158,15 +158,17 @@ func New(t *testing.T, mtu uint32) *Context { if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNamedNIC(1, "nic1", wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + opts := stack.NICOptions{Name: "nic1"} + if err := s.CreateNICWithOptions(1, wep, opts); err != nil { + t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) if testing.Verbose() { wep2 = sniffer.New(channel.New(1000, mtu, "")) } - if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + opts2 := stack.NICOptions{Name: "nic2"} + if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil { + t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { @@ -473,11 +475,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) { c.t.Fatalf("NewEndpoint failed: %v", err) } - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.EP.SetSockOpt(v); err != nil { + if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { c.t.Fatalf("SetSockOpt failed failed: %v", err) } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1ac4705af..864dc8733 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -456,14 +456,9 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { - return nil -} - -// SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -478,8 +473,20 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - e.v6only = v != 0 + e.v6only = v + } + + return nil +} +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { case tcpip.TTLOption: e.mu.Lock() e.ttl = uint8(v) @@ -624,19 +631,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() case tcpip.BindToDeviceOption: - e.mu.Lock() - defer e.mu.Unlock() - if v == "" { - e.bindToDevice = 0 - return nil - } - for nicID, nic := range e.stack.NICInfo() { - if nic.Name == string(v) { - e.bindToDevice = nicID - return nil - } + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice } - return tcpip.ErrUnknownDevice + e.mu.Lock() + e.bindToDevice = id + e.mu.Unlock() + return nil case tcpip.BroadcastOption: e.mu.Lock() @@ -660,8 +662,27 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + return v, nil + } + + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -695,22 +716,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrUnknownProtocolOption - } - - e.mu.Lock() - v := e.v6only - e.mu.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil - case *tcpip.TTLOption: e.mu.Lock() *o = tcpip.TTLOption(e.ttl) @@ -757,12 +762,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.BindToDeviceOption: e.mu.RLock() - defer e.mu.RUnlock() - if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { - *o = tcpip.BindToDeviceOption(nic.Name) - return nil - } - *o = tcpip.BindToDeviceOption("") + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.mu.RUnlock() return nil case *tcpip.KeepaliveEnabledOption: diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 7051a7a9c..0a82bc4fa 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -335,7 +335,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.createEndpoint(flow.sockProto()) if flow.isV6Only() { - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { c.t.Fatalf("SetSockOpt failed: %v", err) } } else if flow.isBroadcast() { @@ -508,46 +508,42 @@ func TestBindToDeviceOption(t *testing.T) { } defer ep.Close() - if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { - t.Errorf("CreateNamedNIC failed: %v", err) + opts := stack.NICOptions{Name: "my_device"} + if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { + t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %v", err) - } - - // strPtr is used instead of taking the address of string literals, which is + // nicIDPtr is used instead of taking the address of NICID literals, which is // a compiler error. - strPtr := func(s string) *string { + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { return &s } testActions := []struct { name string - setBindToDevice *string + setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error getBindToDevice tcpip.BindToDeviceOption }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, } for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + bindToDevice := tcpip.BindToDeviceOption(88888) + if err := ep.GetSockOpt(&bindToDevice); err != nil { + t.Errorf("GetSockOpt got %v, want %v", err, nil) } if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) + t.Errorf("bindToDevice got %d, want %d", got, want) } }) } |