diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 152 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state_test.go | 61 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 790 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 5589 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache_test.go | 1584 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_list.go | 221 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_test.go | 2269 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic_test.go | 224 | ||||
-rw-r--r-- | pkg/tcpip/stack/nud_test.go | 816 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer_list.go | 221 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer_test.go | 649 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_state_autogen.go | 1287 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 4543 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_unsafe_state_autogen.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer_test.go | 446 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 555 | ||||
-rw-r--r-- | pkg/tcpip/stack/tuple_list.go | 221 |
17 files changed, 1953 insertions, 17678 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD deleted file mode 100644 index 395ff9a07..000000000 --- a/pkg/tcpip/stack/BUILD +++ /dev/null @@ -1,152 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "most_shards") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "neighbor_entry_list", - out = "neighbor_entry_list.go", - package = "stack", - prefix = "neighborEntry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*neighborEntry", - "Linker": "*neighborEntry", - }, -) - -go_template_instance( - name = "packet_buffer_list", - out = "packet_buffer_list.go", - package = "stack", - prefix = "PacketBuffer", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*PacketBuffer", - "Linker": "*PacketBuffer", - }, -) - -go_template_instance( - name = "tuple_list", - out = "tuple_list.go", - package = "stack", - prefix = "tuple", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*tuple", - "Linker": "*tuple", - }, -) - -go_library( - name = "stack", - srcs = [ - "addressable_endpoint_state.go", - "conntrack.go", - "headertype_string.go", - "hook_string.go", - "icmp_rate_limit.go", - "iptables.go", - "iptables_state.go", - "iptables_targets.go", - "iptables_types.go", - "neighbor_cache.go", - "neighbor_entry.go", - "neighbor_entry_list.go", - "neighborstate_string.go", - "nic.go", - "nic_stats.go", - "nud.go", - "packet_buffer.go", - "packet_buffer_list.go", - "packet_buffer_unsafe.go", - "pending_packets.go", - "rand.go", - "registration.go", - "route.go", - "stack.go", - "stack_global_state.go", - "stack_options.go", - "tcp.go", - "transport_demuxer.go", - "tuple_list.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/atomicbitops", - "//pkg/buffer", - "//pkg/ilist", - "//pkg/log", - "//pkg/rand", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/hash/jenkins", - "//pkg/tcpip/header", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/transport/tcpconntrack", - "//pkg/waiter", - "@org_golang_x_time//rate:go_default_library", - ], -) - -go_test( - name = "stack_x_test", - size = "medium", - srcs = [ - "addressable_endpoint_state_test.go", - "ndp_test.go", - "nud_test.go", - "stack_test.go", - "transport_demuxer_test.go", - "transport_test.go", - ], - shard_count = most_shards, - deps = [ - ":stack", - "//pkg/rand", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/ports", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "stack_test", - size = "small", - srcs = [ - "forwarding_test.go", - "neighbor_cache_test.go", - "neighbor_entry_test.go", - "nic_test.go", - "packet_buffer_test.go", - ], - library = ":stack", - deps = [ - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/testutil", - "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", - ], -) diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go deleted file mode 100644 index 140f146f6..000000000 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -// TestAddressableEndpointStateCleanup tests that cleaning up an addressable -// endpoint state removes permanent addresses and leaves groups. -func TestAddressableEndpointStateCleanup(t *testing.T) { - var ep fakeNetworkEndpoint - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - var s stack.AddressableEndpointState - s.Init(&ep) - - addr := tcpip.AddressWithPrefix{ - Address: "\x01", - PrefixLen: 8, - } - - { - ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) - if err != nil { - t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err) - } - // We don't need the address endpoint. - ep.DecRef() - } - { - ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) - if ep == nil { - t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address) - } - ep.DecRef() - } - - s.Cleanup() - if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil { - ep.DecRef() - t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) - } -} diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go deleted file mode 100644 index 72f66441f..000000000 --- a/pkg/tcpip/stack/forwarding_test.go +++ /dev/null @@ -1,790 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "encoding/binary" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -const ( - fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fwdTestNetHeaderLen = 12 - fwdTestNetDefaultPrefixLen = 8 - - // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match - // the MTU of loopback interfaces on linux systems. - fwdTestNetDefaultMTU = 65536 - - dstAddrOffset = 0 - srcAddrOffset = 1 - protocolNumberOffset = 2 -) - -var _ LinkAddressResolver = (*fwdTestNetworkEndpoint)(nil) -var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) - -// fwdTestNetworkEndpoint is a network-layer protocol endpoint. -// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only -// use the first three: destination address, source address, and transport -// protocol. They're all one byte fields to simplify parsing. -type fwdTestNetworkEndpoint struct { - AddressableEndpointState - - nic NetworkInterface - proto *fwdTestNetworkProtocol - dispatcher TransportDispatcher - - mu struct { - sync.RWMutex - forwarding bool - } -} - -func (*fwdTestNetworkEndpoint) Enable() tcpip.Error { - return nil -} - -func (*fwdTestNetworkEndpoint) Enabled() bool { - return true -} - -func (*fwdTestNetworkEndpoint) Disable() {} - -func (f *fwdTestNetworkEndpoint) MTU() uint32 { - return f.nic.MTU() - uint32(f.MaxHeaderLength()) -} - -func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { - return 123 -} - -func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { - if _, _, ok := f.proto.Parse(pkt); !ok { - return - } - - netHdr := pkt.NetworkHeader().View() - _, dst := f.proto.ParseAddresses(netHdr) - - addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint) - if addressEndpoint != nil { - addressEndpoint.DecRef() - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt) - return - } - - r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */) - if err != nil { - return - } - defer r.Release() - - vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) - pkt = NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: vv.ToView().ToVectorisedView(), - }) - // TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets. - _ = r.WriteHeaderIncludedPacket(pkt) -} - -func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { - return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen -} - -func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return f.proto.Number() -} - -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { - // Add the protocol's header to the packet and send it to the link - // endpoint. - b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) - b[dstAddrOffset] = r.RemoteAddress()[0] - b[srcAddrOffset] = r.LocalAddress()[0] - b[protocolNumberOffset] = byte(params.Protocol) - - return f.nic.WritePacket(r, fwdTestNetNumber, pkt) -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { - panic("not implemented") -} - -func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error { - // The network header should not already be populated. - if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok { - return &tcpip.ErrMalformedHeader{} - } - - return f.nic.WritePacket(r, fwdTestNetNumber, pkt) -} - -func (f *fwdTestNetworkEndpoint) Close() { - f.AddressableEndpointState.Cleanup() -} - -// Stats implements stack.NetworkEndpoint. -func (*fwdTestNetworkEndpoint) Stats() NetworkEndpointStats { - return &fwdTestNetworkEndpointStats{} -} - -var _ NetworkEndpointStats = (*fwdTestNetworkEndpointStats)(nil) - -type fwdTestNetworkEndpointStats struct{} - -// IsNetworkEndpointStats implements stack.NetworkEndpointStats. -func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {} - -var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) - -// fwdTestNetworkProtocol is a network-layer protocol that implements Address -// resolution. -type fwdTestNetworkProtocol struct { - stack *Stack - - neigh *neighborCache - addrResolveDelay time.Duration - onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress) - onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) -} - -func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { - return fwdTestNetNumber -} - -func (*fwdTestNetworkProtocol) MinimumPacketSize() int { - return fwdTestNetHeaderLen -} - -func (*fwdTestNetworkProtocol) DefaultPrefixLen() int { - return fwdTestNetDefaultPrefixLen -} - -func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) -} - -func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen) - if !ok { - return 0, false, false - } - return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true -} - -func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, dispatcher TransportDispatcher) NetworkEndpoint { - e := &fwdTestNetworkEndpoint{ - nic: nic, - proto: f, - dispatcher: dispatcher, - } - e.AddressableEndpointState.Init(e) - return e -} - -func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} -} - -func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} -} - -func (*fwdTestNetworkProtocol) Close() {} - -func (*fwdTestNetworkProtocol) Wait() {} - -func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { - if fn := f.proto.onLinkAddressResolved; fn != nil { - f.proto.stack.clock.AfterFunc(f.proto.addrResolveDelay, func() { - fn(f.proto.neigh, addr, remoteLinkAddr) - }) - } - return nil -} - -func (f *fwdTestNetworkEndpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if fn := f.proto.onResolveStaticAddress; fn != nil { - return fn(addr) - } - return "", false -} - -func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return fwdTestNetNumber -} - -// Forwarding implements stack.ForwardingNetworkEndpoint. -func (f *fwdTestNetworkEndpoint) Forwarding() bool { - f.mu.RLock() - defer f.mu.RUnlock() - return f.mu.forwarding - -} - -// SetForwarding implements stack.ForwardingNetworkEndpoint. -func (f *fwdTestNetworkEndpoint) SetForwarding(v bool) { - f.mu.Lock() - defer f.mu.Unlock() - f.mu.forwarding = v -} - -// fwdTestPacketInfo holds all the information about an outbound packet. -type fwdTestPacketInfo struct { - RemoteLinkAddress tcpip.LinkAddress - LocalLinkAddress tcpip.LinkAddress - Pkt *PacketBuffer -} - -var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil) - -type fwdTestLinkEndpoint struct { - dispatcher NetworkDispatcher - mtu uint32 - linkAddr tcpip.LinkAddress - - // C is where outbound packets are queued. - C chan fwdTestPacketInfo -} - -// InjectInbound injects an inbound packet. -func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - e.InjectLinkAddr(protocol, "", pkt) -} - -// InjectLinkAddr injects an inbound packet with a remote link address. -func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) -} - -// Attach saves the stack network-layer dispatcher for use later when packets -// are injected. -func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) { - e.dispatcher = dispatcher -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *fwdTestLinkEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized -// during construction. -func (e *fwdTestLinkEndpoint) MTU() uint32 { - return e.mtu -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { - caps := LinkEndpointCapabilities(0) - return caps | CapabilityResolutionRequired -} - -// MaxHeaderLength returns the maximum size of the link layer header. Given it -// doesn't have a header, it just returns 0. -func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} - -// LinkAddress returns the link address of this endpoint. -func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return e.linkAddr -} - -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { - p := fwdTestPacketInfo{ - RemoteLinkAddress: r.RemoteLinkAddress, - LocalLinkAddress: r.LocalLinkAddress, - Pkt: pkt, - } - - select { - case e.C <- p: - default: - } - - return nil -} - -// WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - n := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.WritePacket(r, protocol, pkt) - n++ - } - - return n, nil -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*fwdTestLinkEndpoint) Wait() {} - -// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. -func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { - panic("not implemented") -} - -// AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) { - panic("not implemented") -} - -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.ManualClock, *fwdTestLinkEndpoint, *fwdTestLinkEndpoint) { - clock := faketime.NewManualClock() - // Create a stack with the network protocol and two NICs. - s := New(Options{ - NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { - proto.stack = s - return proto - }}, - Clock: clock, - }) - - protoNum := proto.Number() - if err := s.SetForwardingDefaultAndAllNICs(protoNum, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", protoNum, err) - } - - // NIC 1 has the link address "a", and added the network address 1. - ep1 := &fwdTestLinkEndpoint{ - C: make(chan fwdTestPacketInfo, 300), - mtu: fwdTestNetDefaultMTU, - linkAddr: "a", - } - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC #1 failed:", err) - } - if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) - } - - // NIC 2 has the link address "b", and added the network address 2. - ep2 := &fwdTestLinkEndpoint{ - C: make(chan fwdTestPacketInfo, 300), - mtu: fwdTestNetDefaultMTU, - linkAddr: "b", - } - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC #2 failed:", err) - } - if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) - } - - nic, ok := s.nics[2] - if !ok { - t.Fatal("NIC 2 does not exist") - } - - if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok { - proto.neigh = &l.neigh - } - - // Route all packets to NIC 2. - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) - } - - return clock, ep1, ep2 -} - -func TestForwardingWithStaticResolver(t *testing.T) { - // Create a network protocol with a static resolver. - proto := &fwdTestNetworkProtocol{ - onResolveStaticAddress: - // The network address 3 is resolved to the link address "c". - func(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "\x03" { - return "c", true - } - return "", false - }, - } - - clock, ep1, ep2 := fwdTestNetFactory(t, proto) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - clock.Advance(proto.addrResolveDelay) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - // Test that the static address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } -} - -func TestForwardingWithFakeResolver(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any address will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - clock, ep1, ep2 := fwdTestNetFactory(t, &proto) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - clock.Advance(proto.addrResolveDelay) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } -} - -func TestForwardingWithNoResolver(t *testing.T) { - // Create a network protocol without a resolver. - proto := &fwdTestNetworkProtocol{} - - // Whether or not we use the neighbor cache here does not matter since - // neither linkAddrCache nor neighborCache will be used. - clock, ep1, ep2 := fwdTestNetFactory(t, proto) - - // inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - clock.Advance(proto.addrResolveDelay) - select { - case <-ep2.C: - t.Fatal("Packet should not be forwarded") - default: - } -} - -func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 50 * time.Millisecond, - onLinkAddressResolved: func(*neighborCache, tcpip.Address, tcpip.LinkAddress) { - // Don't resolve the link address. - }, - } - - clock, ep1, ep2 := fwdTestNetFactory(t, proto) - - const numPackets int = 5 - // These packets will all be enqueued in the packet queue to wait for link - // address resolution. - for i := 0; i < numPackets; i++ { - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - // All packets should fail resolution. - for i := 0; i < numPackets; i++ { - clock.Advance(proto.addrResolveDelay) - select { - case got := <-ep2.C: - t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got) - default: - } - } -} - -func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - } - }, - } - clock, ep1, ep2 := fwdTestNetFactory(t, &proto) - - // Inject an inbound packet to address 4 on NIC 1. This packet should - // not be forwarded. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf = buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - clock.Advance(proto.addrResolveDelay) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } -} - -func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - clock, ep1, ep2 := fwdTestNetFactory(t, &proto) - - // Inject two inbound packets to address 3 on NIC 1. - for i := 0; i < 2; i++ { - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < 2; i++ { - var p fwdTestPacketInfo - - clock.Advance(proto.addrResolveDelay) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } -} - -func TestForwardingWithFakeResolverManyPackets(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - clock, ep1, ep2 := fwdTestNetFactory(t, &proto) - - for i := 0; i < maxPendingPacketsPerResolution+5; i++ { - // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - // Set the packet sequence number. - binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < maxPendingPacketsPerResolution; i++ { - var p fwdTestPacketInfo - - clock.Advance(proto.addrResolveDelay) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - b := PayloadSince(p.Pkt.NetworkHeader()) - if b[dstAddrOffset] != 3 { - t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) - } - if len(b) < fwdTestNetHeaderLen+2 { - t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) - } - seqNumBuf := b[fwdTestNetHeaderLen:] - - // The first 5 packets should not be forwarded so the sequence number should - // start with 5. - want := uint16(i + 5) - if n := binary.BigEndian.Uint16(seqNumBuf); n != want { - t.Fatalf("got the packet #%d, want = #%d", n, want) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } -} - -func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - clock, ep1, ep2 := fwdTestNetFactory(t, &proto) - - for i := 0; i < maxPendingResolutions+5; i++ { - // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. - // Each packet has a different destination address (3 to - // maxPendingResolutions + 7). - buf := buffer.NewView(30) - buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < maxPendingResolutions; i++ { - var p fwdTestPacketInfo - - clock.Advance(proto.addrResolveDelay) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - // The first 5 packets (address 3 to 7) should not be forwarded - // because their address resolutions are interrupted. - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } -} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go deleted file mode 100644 index 1407f497e..000000000 --- a/pkg/tcpip/stack/ndp_test.go +++ /dev/null @@ -1,5589 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "bytes" - "encoding/binary" - "fmt" - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - cryptorand "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var ( - addr1 = testutil.MustParse6("a00::1") - addr2 = testutil.MustParse6("a00::2") - addr3 = testutil.MustParse6("a00::3") -) - -const ( - linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") - linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") - linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") - - defaultPrefixLen = 128 - - // Extra time to use when waiting for an async event to occur. - defaultAsyncPositiveEventTimeout = 10 * time.Second - - // Extra time to use when waiting for an async event to not occur. - // - // Since a negative check is used to make sure an event did not happen, it is - // okay to use a smaller timeout compared to the positive case since execution - // stall in regards to the monotonic clock will not affect the expected - // outcome. - defaultAsyncNegativeEventTimeout = time.Second -) - -var ( - llAddr1 = header.LinkLocalAddr(linkAddr1) - llAddr2 = header.LinkLocalAddr(linkAddr2) - llAddr3 = header.LinkLocalAddr(linkAddr3) - llAddr4 = header.LinkLocalAddr(linkAddr4) - dstAddr = tcpip.FullAddress{ - Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - Port: 25, - } -) - -func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix { - if !header.IsValidUnicastEthernetAddress(linkAddr) { - return tcpip.AddressWithPrefix{} - } - - addrBytes := []byte(subnet.ID()) - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) - return tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: 64, - } -} - -// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent -// tcpip.Subnet, and an address where the lower half of the address is composed -// of the EUI-64 of linkAddr if it is a valid unicast ethernet address. -func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) { - prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0} - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address(prefixBytes), - PrefixLen: 64, - } - - subnet := prefix.Subnet() - - return prefix, subnet, addrForSubnet(subnet, linkAddr) -} - -// ndpDADEvent is a set of parameters that was passed to -// ndpDispatcher.OnDuplicateAddressDetectionResult. -type ndpDADEvent struct { - nicID tcpip.NICID - addr tcpip.Address - res stack.DADResult -} - -type ndpOffLinkRouteEvent struct { - nicID tcpip.NICID - subnet tcpip.Subnet - router tcpip.Address - prf header.NDPRoutePreference - // true if route was updated, false if invalidated. - updated bool -} - -type ndpPrefixEvent struct { - nicID tcpip.NICID - prefix tcpip.Subnet - // true if prefix was discovered, false if invalidated. - discovered bool -} - -type ndpAutoGenAddrEventType int - -const ( - newAddr ndpAutoGenAddrEventType = iota - deprecatedAddr - invalidatedAddr -) - -type ndpAutoGenAddrEvent struct { - nicID tcpip.NICID - addr tcpip.AddressWithPrefix - eventType ndpAutoGenAddrEventType -} - -type ndpRDNSS struct { - addrs []tcpip.Address - lifetime time.Duration -} - -type ndpRDNSSEvent struct { - nicID tcpip.NICID - rdnss ndpRDNSS -} - -type ndpDNSSLEvent struct { - nicID tcpip.NICID - domainNames []string - lifetime time.Duration -} - -type ndpDHCPv6Event struct { - nicID tcpip.NICID - configuration ipv6.DHCPv6ConfigurationFromNDPRA -} - -var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) - -// ndpDispatcher implements NDPDispatcher so tests can know when various NDP -// related events happen for test purposes. -type ndpDispatcher struct { - dadC chan ndpDADEvent - offLinkRouteC chan ndpOffLinkRouteEvent - prefixC chan ndpPrefixEvent - autoGenAddrC chan ndpAutoGenAddrEvent - rdnssC chan ndpRDNSSEvent - dnsslC chan ndpDNSSLEvent - routeTable []tcpip.Route - dhcpv6ConfigurationC chan ndpDHCPv6Event -} - -// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionResult. -func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) { - if n.dadC != nil { - n.dadC <- ndpDADEvent{ - nicID, - addr, - res, - } - } -} - -// Implements ipv6.NDPDispatcher.OnOffLinkRouteUpdated. -func (n *ndpDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference) { - if c := n.offLinkRouteC; c != nil { - c <- ndpOffLinkRouteEvent{ - nicID, - subnet, - router, - prf, - true, - } - } -} - -// Implements ipv6.NDPDispatcher.OnOffLinkRouteInvalidated. -func (n *ndpDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) { - if c := n.offLinkRouteC; c != nil { - var prf header.NDPRoutePreference - c <- ndpOffLinkRouteEvent{ - nicID, - subnet, - router, - prf, - false, - } - } -} - -// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) { - if c := n.prefixC; c != nil { - c <- ndpPrefixEvent{ - nicID, - prefix, - true, - } - } -} - -// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated. -func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { - if c := n.prefixC; c != nil { - c <- ndpPrefixEvent{ - nicID, - prefix, - false, - } - } -} - -func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { - if c := n.autoGenAddrC; c != nil { - c <- ndpAutoGenAddrEvent{ - nicID, - addr, - newAddr, - } - } -} - -func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { - if c := n.autoGenAddrC; c != nil { - c <- ndpAutoGenAddrEvent{ - nicID, - addr, - deprecatedAddr, - } - } -} - -func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { - if c := n.autoGenAddrC; c != nil { - c <- ndpAutoGenAddrEvent{ - nicID, - addr, - invalidatedAddr, - } - } -} - -// Implements ipv6.NDPDispatcher.OnRecursiveDNSServerOption. -func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { - if c := n.rdnssC; c != nil { - c <- ndpRDNSSEvent{ - nicID, - ndpRDNSS{ - addrs, - lifetime, - }, - } - } -} - -// Implements ipv6.NDPDispatcher.OnDNSSearchListOption. -func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) { - if n.dnsslC != nil { - n.dnsslC <- ndpDNSSLEvent{ - nicID, - domainNames, - lifetime, - } - } -} - -// Implements ipv6.NDPDispatcher.OnDHCPv6Configuration. -func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration ipv6.DHCPv6ConfigurationFromNDPRA) { - if c := n.dhcpv6ConfigurationC; c != nil { - c <- ndpDHCPv6Event{ - nicID, - configuration, - } - } -} - -// channelLinkWithHeaderLength is a channel.Endpoint with a configurable -// header length. -type channelLinkWithHeaderLength struct { - *channel.Endpoint - headerLength uint16 -} - -func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { - return l.headerLength -} - -// Check e to make sure that the event is for addr on nic with ID 1, and the -// resolved flag set to resolved with the specified err. -func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) string { - return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, res: res}, e, cmp.AllowUnexported(e)) -} - -// TestDADDisabled tests that an address successfully resolves immediately -// when DAD is not enabled (the default for an empty stack.Options). -func TestDADDisabled(t *testing.T) { - const nicID = 1 - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - })}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr1, - PrefixLen: defaultPrefixLen, - } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) - } - - // Should get the address immediately since we should not have performed - // DAD on it. - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatal(err) - } - - // We should not have sent any NDP NS messages. - if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 { - t.Fatalf("got NeighborSolicit = %d, want = 0", got) - } -} - -func TestDADResolveLoopback(t *testing.T) { - const nicID = 1 - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - - dadConfigs := stack.DADConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 1, - } - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - Clock: clock, - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - DADConfigs: dadConfigs, - })}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr1, - PrefixLen: defaultPrefixLen, - } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) - } - - // Address should not be considered bound to the NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // DAD should not resolve after the normal resolution time since our DAD - // message was looped back - we should extend our DAD process. - dadResolutionTime := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer - clock.Advance(dadResolutionTime) - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Error(err) - } - - // Make sure the address does not resolve before the extended resolution time - // has passed. - const delta = time.Nanosecond - // DAD will send extra NS probes if an NS message is looped back. - const extraTransmits = 3 - clock.Advance(dadResolutionTime*extraTransmits - delta) - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Error(err) - } - - // DAD should now resolve. - clock.Advance(delta) - if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } -} - -// TestDADResolve tests that an address successfully resolves after performing -// DAD for various values of DupAddrDetectTransmits and RetransmitTimer. -// Included in the subtests is a test to make sure that an invalid -// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. -// This tests also validates the NDP NS packet that is transmitted. -func TestDADResolve(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - linkHeaderLen uint16 - dupAddrDetectTransmits uint8 - retransTimer time.Duration - expectedRetransmitTimer time.Duration - }{ - { - name: "1:1s:1s", - dupAddrDetectTransmits: 1, - retransTimer: time.Second, - expectedRetransmitTimer: time.Second, - }, - { - name: "2:1s:1s", - linkHeaderLen: 1, - dupAddrDetectTransmits: 2, - retransTimer: time.Second, - expectedRetransmitTimer: time.Second, - }, - { - name: "1:2s:2s", - linkHeaderLen: 2, - dupAddrDetectTransmits: 1, - retransTimer: 2 * time.Second, - expectedRetransmitTimer: 2 * time.Second, - }, - // 0s is an invalid RetransmitTimer timer and will be fixed to - // the default RetransmitTimer value of 1s. - { - name: "1:0s:1s", - linkHeaderLen: 3, - dupAddrDetectTransmits: 1, - retransTimer: 0, - expectedRetransmitTimer: time.Second, - }, - } - - nonces := [][]byte{ - {1, 2, 3, 4, 5, 6}, - {7, 8, 9, 10, 11, 12}, - } - - var secureRNGBytes []byte - for _, n := range nonces { - secureRNGBytes = append(secureRNGBytes, n...) - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - var secureRNG bytes.Reader - secureRNG.Reset(secureRNGBytes) - - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - Clock: clock, - RandSource: rand.NewSource(time.Now().UnixNano()), - SecureRNG: &secureRNG, - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - DADConfigs: stack.DADConfigurations{ - RetransmitTimer: test.retransTimer, - DupAddrDetectTransmits: test.dupAddrDetectTransmits, - }, - })}, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // We add a default route so the call to FindRoute below will succeed - // once we have an assigned address. - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: addr3, - NIC: nicID, - }}) - - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr1, - PrefixLen: defaultPrefixLen, - } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) - } - - // Make sure the address does not resolve before the resolution time has - // passed. - const delta = time.Nanosecond - clock.Advance(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Error(err) - } - // Should not get a route even if we specify the local address as the - // tentative address. - { - r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) - } - if r != nil { - r.Release() - } - } - { - r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) - } - if r != nil { - r.Release() - } - } - - if t.Failed() { - t.FailNow() - } - - // Wait for DAD to resolve. - clock.Advance(delta) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatalf("expected DAD event for %s on NIC(%d)", addr1, nicID) - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Error(err) - } - // Should get a route using the address now that it is resolved. - { - r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) - if err != nil { - t.Errorf("got FindRoute(%d, '', %s, %d, false): %s", nicID, addr2, header.IPv6ProtocolNumber, err) - } else if r.LocalAddress() != addr1 { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), addr1) - } - r.Release() - } - { - r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) - if err != nil { - t.Errorf("got FindRoute(%d, %s, %s, %d, false): %s", nicID, addr1, addr2, header.IPv6ProtocolNumber, err) - } else if r.LocalAddress() != addr1 { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), addr1) - } - if r != nil { - r.Release() - } - } - - if t.Failed() { - t.FailNow() - } - - // Should not have sent any more NS messages. - if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { - t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) - } - - // Validate the sent Neighbor Solicitation messages. - for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p, ok := e.Read() - if !ok { - t.Fatal("packet didn't arrive") - } - - // Make sure its an IPv6 packet. - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - - // Make sure the right remote link address is used. - snmc := header.SolicitedNodeAddr(addr1) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } - - // Check NDP NS packet. - // - // As per RFC 4861 section 4.3, a possible option is the Source Link - // Layer option, but this option MUST NOT be included when the source - // address of the packet is the unspecified address. - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(addr1), - checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[i])}), - )) - - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - }) - } -} - -func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(tgt) - snmc := header.SolicitedNodeAddr(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: header.IPv6Any, - Dst: snmc, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, - }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) -} - -// TestDADFail tests to make sure that the DAD process fails if another node is -// detected to be performing DAD on the same address (receive an NS message from -// a node doing DAD for the same address), or if another node is detected to own -// the address already (receive an NA message for the tentative address). -func TestDADFail(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - rxPkt func(e *channel.Endpoint, tgt tcpip.Address) - getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - expectedHolderLinkAddress tcpip.LinkAddress - }{ - { - name: "RxSolicit", - rxPkt: rxNDPSolicit, - getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return s.NeighborSolicit - }, - expectedHolderLinkAddress: "", - }, - { - name: "RxAdvert", - rxPkt: func(e *channel.Endpoint, tgt tcpip.Address) { - naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) - pkt := header.ICMPv6(hdr.Prepend(naSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(tgt) - na.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: tgt, - Dst: header.IPv6AllNodesMulticastAddress, - })) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: 255, - SrcAddr: tgt, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) - }, - getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return s.NeighborAdvert - }, - expectedHolderLinkAddress: linkAddr1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - dadConfigs := stack.DefaultDADConfigurations() - dadConfigs.RetransmitTimer = time.Second * 2 - - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - DADConfigs: dadConfigs, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } - - // Address should not be considered bound to the NIC yet - // (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Receive a packet to simulate an address conflict. - test.rxPkt(e, addr1) - - stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived) - if got := stat.Value(); got != 1 { - t.Fatalf("got stat = %d, want = 1", got) - } - - // Wait for DAD to fail and make sure the address did - // not get resolved. - clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - // If we don't get a failure event after the - // expected resolution time + extra 1s buffer, - // something is wrong. - t.Fatal("timed out waiting for DAD failure") - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Attempting to add the address again should not fail if the address's - // state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } - }) - } -} - -func TestDADStop(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - stopFn func(t *testing.T, s *stack.Stack) - skipFinalAddrCheck bool - }{ - // Tests to make sure that DAD stops when an address is removed. - { - name: "Remove address", - stopFn: func(t *testing.T, s *stack.Stack) { - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err) - } - }, - }, - - // Tests to make sure that DAD stops when the NIC is disabled. - { - name: "Disable NIC", - stopFn: func(t *testing.T, s *stack.Stack) { - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("DisableNIC(%d): %s", nicID, err) - } - }, - }, - - // Tests to make sure that DAD stops when the NIC is removed. - { - name: "Remove NIC", - stopFn: func(t *testing.T, s *stack.Stack) { - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("RemoveNIC(%d): %s", nicID, err) - } - }, - // The NIC is removed so we can't check its addresses after calling - // stopFn. - skipFinalAddrCheck: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - - dadConfigs := stack.DADConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 2, - } - - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - DADConfigs: dadConfigs, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } - - // Address should not be considered bound to the NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - test.stopFn(t, s) - - // Wait for DAD to fail (since the address was removed during DAD). - clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") - } - - if !test.skipFinalAddrCheck { - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - } - - // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 { - t.Errorf("got NeighborSolicit = %d, want <= 1", got) - } - }) - } -} - -// TestSetNDPConfigurations tests that we can update and use per-interface NDP -// configurations without affecting the default NDP configurations or other -// interfaces' configurations. -func TestSetNDPConfigurations(t *testing.T) { - const nicID1 = 1 - const nicID2 = 2 - const nicID3 = 3 - - tests := []struct { - name string - dupAddrDetectTransmits uint8 - retransmitTimer time.Duration - expectedRetransmitTimer time.Duration - }{ - { - "OK", - 1, - time.Second, - time.Second, - }, - { - "Invalid Retransmit Timer", - 1, - 0, - time.Second, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatalf("expected DAD event for %s", addr) - } - } - - // This NIC(1)'s NDP configurations will be updated to - // be different from the default. - if err := s.CreateNIC(nicID1, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - - // Created before updating NIC(1)'s NDP configurations - // but updating NIC(1)'s NDP configurations should not - // affect other existing NICs. - if err := s.CreateNIC(nicID2, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - - // Update the configurations on NIC(1) to use DAD. - if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err) - } else { - dad := ipv6Ep.(stack.DuplicateAddressDetector) - dad.SetDADConfigurations(stack.DADConfigurations{ - DupAddrDetectTransmits: test.dupAddrDetectTransmits, - RetransmitTimer: test.retransmitTimer, - }) - } - - // Created after updating NIC(1)'s NDP configurations - // but the stack's default NDP configurations should not - // have been updated. - if err := s.CreateNIC(nicID3, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID3, err) - } - - // Add addresses for each NIC. - addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err) - } - addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err) - } - expectDADEvent(nicID2, addr2) - addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err) - } - expectDADEvent(nicID3, addr3) - - // Address should not be considered bound to NIC(1) yet - // (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Should get the address on NIC(2) and NIC(3) - // immediately since we should not have performed DAD on - // it as the stack was configured to not do DAD by - // default and we only updated the NDP configurations on - // NIC(1). - if err := checkGetMainNICAddress(s, nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { - t.Fatal(err) - } - if err := checkGetMainNICAddress(s, nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { - t.Fatal(err) - } - - // Sleep until right before resolution to make sure the address didn't - // resolve on NIC(1) yet. - const delta = 1 - clock.Advance(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) - if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Wait for DAD to resolve. - clock.Advance(delta) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for DAD resolution") - } - if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { - t.Fatal(err) - } - }) - } -} - -// raBuf returns a valid NDP Router Advertisement with options, router -// preference and DHCPv6 configurations specified. -func raBuf(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, prf header.NDPRoutePreference, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - const flagsByte = 1 - const routerLifetimeOffset = 2 - - icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(0) - raPayload := pkt.MessageBody() - ra := header.NDPRouterAdvert(raPayload) - // Populate the Router Lifetime. - binary.BigEndian.PutUint16(raPayload[routerLifetimeOffset:], rl) - // Populate the Managed Address flag field. - if managedAddress { - // The Managed Addresses flag field is the 7th bit of the flags byte. - raPayload[flagsByte] |= 1 << 7 - } - // Populate the Other Configurations flag field. - if otherConfigurations { - // The Other Configurations flag field is the 6th bit of the flags byte. - raPayload[flagsByte] |= 1 << 6 - } - // The Prf field is held in the flags byte. - raPayload[flagsByte] |= byte(prf) << 3 - opts := ra.Options() - opts.Serialize(optSer) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: ip, - Dst: header.IPv6AllNodesMulticastAddress, - })) - payloadLength := hdr.UsedLength() - iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - iph.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: header.NDPHopLimit, - SrcAddr: ip, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) - - return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) -} - -// raBufWithOpts returns a valid NDP Router Advertisement with options. -// -// Note, raBufWithOpts does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, 0 /* prf */, optSer) -} - -// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related -// fields set. -// -// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the -// DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) *stack.PacketBuffer { - return raBuf(ip, 0, managedAddresses, otherConfigurations, 0 /* prf */, header.NDPOptionsSerializer{}) -} - -// raBuf returns a valid NDP Router Advertisement. -// -// Note, raBuf does not populate any of the RA fields other than the -// Router Lifetime. -func raBufSimple(ip tcpip.Address, rl uint16) *stack.PacketBuffer { - return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) -} - -// raBufWithPrf returns a valid NDP Router Advertisement with a preference. -// -// Note, raBufWithPrf does not populate any of the RA fields other than the -// Router Lifetime and Default Router Preference fields. -func raBufWithPrf(ip tcpip.Address, rl uint16, prf header.NDPRoutePreference) *stack.PacketBuffer { - return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, prf, header.NDPOptionsSerializer{}) -} - -// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix -// Information option. -// -// Note, raBufWithPI does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer { - flags := uint8(0) - if onLink { - // The OnLink flag is the 7th bit in the flags byte. - flags |= 1 << 7 - } - if auto { - // The Address Auto-Configuration flag is the 6th bit in the - // flags byte. - flags |= 1 << 6 - } - - // A valid header.NDPPrefixInformation must be 30 bytes. - buf := [30]byte{} - // The first byte in a header.NDPPrefixInformation is the Prefix Length - // field. - buf[0] = uint8(prefix.PrefixLen) - // The 2nd byte within a header.NDPPrefixInformation is the Flags field. - buf[1] = flags - // The Valid Lifetime field starts after the 2nd byte within a - // header.NDPPrefixInformation. - binary.BigEndian.PutUint32(buf[2:], vl) - // The Preferred Lifetime field starts after the 6th byte within a - // header.NDPPrefixInformation. - binary.BigEndian.PutUint32(buf[6:], pl) - // The Prefix Address field starts after the 14th byte within a - // header.NDPPrefixInformation. - copy(buf[14:], prefix.Address) - return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{ - header.NDPPrefixInformation(buf[:]), - }) -} - -func TestDynamicConfigurationsDisabled(t *testing.T) { - const ( - nicID = 1 - maxRtrSolicitDelay = time.Second - ) - - prefix := tcpip.AddressWithPrefix{ - Address: testutil.MustParse6("102:304:506:708::"), - PrefixLen: 64, - } - - tests := []struct { - name string - config func(bool) ipv6.NDPConfigurations - ra *stack.PacketBuffer - }{ - { - name: "No Router Discovery", - config: func(enable bool) ipv6.NDPConfigurations { - return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable} - }, - ra: raBufSimple(llAddr2, 1000), - }, - { - name: "No Prefix Discovery", - config: func(enable bool) ipv6.NDPConfigurations { - return ipv6.NDPConfigurations{DiscoverOnLinkPrefixes: enable} - }, - ra: raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0), - }, - { - name: "No Autogenerate Addresses", - config: func(enable bool) ipv6.NDPConfigurations { - return ipv6.NDPConfigurations{AutoGenGlobalAddresses: enable} - }, - ra: raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // Being configured to discover routers/prefixes or auto-generate - // addresses means RAs must be handled, and router/prefix discovery or - // SLAAC must be enabled. - // - // This tests all possible combinations of the configurations where - // router/prefix discovery or SLAAC are disabled. - for i := 0; i < 7; i++ { - handle := ipv6.HandlingRAsDisabled - if i&1 != 0 { - handle = ipv6.HandlingRAsEnabledWhenForwardingDisabled - } - enable := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), - prefixC: make(chan ndpPrefixEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - ndpConfigs := test.config(enable) - ndpConfigs.HandleRAs = handle - ndpConfigs.MaxRtrSolicitations = 1 - ndpConfigs.RtrSolicitationInterval = maxRtrSolicitDelay - ndpConfigs.MaxRtrSolicitationDelay = maxRtrSolicitDelay - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - Clock: clock, - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - })}, - }) - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) - } - - e := channel.New(1, 1280, linkAddr1) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - handleRAsDisabled := handle == ipv6.HandlingRAsDisabled || forwarding - ep, err := s.GetNetworkEndpoint(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ipv6.ProtocolNumber, err) - } - stats := ep.Stats() - v6Stats, ok := stats.(*ipv6.Stats) - if !ok { - t.Fatalf("got v6Stats = %T, expected = %T", stats, v6Stats) - } - - // Make sure that when handling RAs are enabled, we solicit routers. - clock.Advance(maxRtrSolicitDelay) - if got, want := v6Stats.ICMP.PacketsSent.RouterSolicit.Value(), boolToUint64(!handleRAsDisabled); got != want { - t.Errorf("got v6Stats.ICMP.PacketsSent.RouterSolicit.Value() = %d, want = %d", got, want) - } - if handleRAsDisabled { - if p, ok := e.Read(); ok { - t.Errorf("unexpectedly got a packet = %#v", p) - } - } else if p, ok := e.Read(); !ok { - t.Error("expected router solicitation packet") - } else if p.Proto != header.IPv6ProtocolNumber { - t.Errorf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } else { - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(nil)), - ) - } - - // Make sure we do not discover any routers or prefixes, or perform - // SLAAC on reception of an RA. - e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone()) - // Make sure that the unhandled RA stat is only incremented when - // handling RAs is disabled. - if got, want := v6Stats.UnhandledRouterAdvertisements.Value(), boolToUint64(handleRAsDisabled); got != want { - t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) - } - select { - case e := <-ndpDisp.offLinkRouteC: - t.Errorf("unexpectedly updated an off-link route when configured not to: %#v", e) - default: - } - select { - case e := <-ndpDisp.prefixC: - t.Errorf("unexpectedly discovered a prefix when configured not to: %#v", e) - default: - } - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("unexpectedly auto-generated an address when configured not to: %#v", e) - default: - } - }) - } - }) - } -} - -func boolToUint64(v bool) uint64 { - if v { - return 1 - } - return 0 -} - -func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string { - return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: header.IPv6EmptySubnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e)) -} - -func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { - tests := [...]struct { - name string - handleRAs ipv6.HandleRAsConfiguration - forwarding bool - }{ - { - name: "Handle RAs when forwarding disabled", - handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - forwarding: false, - }, - { - name: "Always Handle RAs with forwarding disabled", - handleRAs: ipv6.HandlingRAsAlwaysEnabled, - forwarding: false, - }, - { - name: "Always Handle RAs with forwarding enabled", - handleRAs: ipv6.HandlingRAsAlwaysEnabled, - forwarding: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - f(t, test.handleRAs, test.forwarding) - }) - } -} - -func TestRouterDiscovery(t *testing.T) { - const nicID = 1 - - testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { - ndpDisp := ndpDispatcher{ - offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handleRAs, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) { - t.Helper() - - select { - case e := <-ndpDisp.offLinkRouteC: - if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, updated); diff != "" { - t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - } - - expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { - t.Helper() - - clock.Advance(timeout) - select { - case e := <-ndpDisp.offLinkRouteC: - var prf header.NDPRoutePreference - if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, false); diff != "" { - t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for router discovery event") - } - } - - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) - } - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - // Rx an RA from lladdr2 with zero lifetime. It should not be - // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0)) - select { - case <-ndpDisp.offLinkRouteC: - t.Fatal("unexpectedly updated an off-link route with 0 lifetime") - default: - } - - // Rx an RA from lladdr2 with a huge lifetime and reserved preference value - // (which should be interpreted as the default (medium) preference value). - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, 1000, header.ReservedRoutePreference)) - expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) - - // Rx an RA from another router (lladdr3) with non-zero lifetime and - // non-default preference value. - const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr3, l3LifetimeSeconds, header.HighRoutePreference)) - expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true) - - // Rx an RA from lladdr2 with lesser lifetime and default (medium) - // preference value. - const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, l2LifetimeSeconds)) - select { - case <-ndpDisp.offLinkRouteC: - t.Fatal("should not receive a off-link route event when updating lifetimes for known routers") - default: - } - - // Rx an RA from lladdr2 with a different preference. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, l2LifetimeSeconds, header.LowRoutePreference)) - expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true) - - // Wait for lladdr2's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second) - - // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 1000)) - expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) - - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0)) - expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false) - - // Wait for lladdr3's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second) - }) -} - -// TestRouterDiscoveryMaxRouters tests that only -// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered. -func TestRouterDiscoveryMaxRouters(t *testing.T) { - const nicID = 1 - - ndpDisp := ndpDispatcher{ - offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ { - linkAddr := []byte{2, 2, 3, 4, 5, 0} - linkAddr[5] = byte(i) - llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) - - e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr, 5)) - - if i <= ipv6.MaxDiscoveredDefaultRouters { - select { - case e := <-ndpDisp.offLinkRouteC: - if diff := checkOffLinkRouteEvent(e, nicID, llAddr, header.MediumRoutePreference, true); diff != "" { - t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - - } else { - select { - case <-ndpDisp.offLinkRouteC: - t.Fatal("should not have discovered a new router after we already discovered the max number of routers") - default: - } - } - } -} - -// Check e to make sure that the event is for prefix on nic with ID 1, and the -// discovered flag set to discovered. -func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { - return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) -} - -func TestPrefixDiscovery(t *testing.T) { - prefix1, subnet1, _ := prefixSubnetAddr(0, "") - prefix2, subnet2, _ := prefixSubnetAddr(1, "") - prefix3, subnet3, _ := prefixSubnetAddr(2, "") - - testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handleRAs, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() - - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - } - - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - default: - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) - expectPrefixEvent(subnet1, true) - - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) - expectPrefixEvent(subnet2, true) - - // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) - expectPrefixEvent(subnet3, true) - - // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - expectPrefixEvent(subnet1, false) - - // Receive an RA with prefix2 in a PI with lesser lifetime. - lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly received prefix event when updating lifetime") - default: - } - - // Wait for prefix2's most recent invalidation job plus some buffer to - // expire. - clock.Advance(time.Duration(lifetime) * time.Second) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet2, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for prefix discovery event") - } - - // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) - expectPrefixEvent(subnet3, false) - }) -} - -func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { - // Update the infinite lifetime value to a smaller value so we can test - // that when we receive a PI with such a lifetime value, we do not - // invalidate the prefix. - const testInfiniteLifetimeSeconds = 2 - const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second - saved := header.NDPInfiniteLifetime - header.NDPInfiniteLifetime = testInfiniteLifetime - defer func() { - header.NDPInfiniteLifetime = saved - }() - - prefix := tcpip.AddressWithPrefix{ - Address: testutil.MustParse6("102:304:506:708::"), - PrefixLen: 64, - } - subnet := prefix.Subnet() - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() - - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - } - - // Receive an RA with prefix in an NDP Prefix Information option (PI) - // with infinite valid lifetime which should not get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - expectPrefixEvent(subnet, true) - clock.Advance(testInfiniteLifetime) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - default: - } - - // Receive an RA with finite lifetime. - // The prefix should get invalidated after 1s. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - clock.Advance(testInfiniteLifetime) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for prefix discovery event") - } - - // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - expectPrefixEvent(subnet, true) - - // Receive an RA with prefix with an infinite lifetime. - // The prefix should not be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - clock.Advance(testInfiniteLifetime) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - default: - } - - // Receive an RA with a prefix with a lifetime value greater than the - // set infinite lifetime value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) - clock.Advance((testInfiniteLifetimeSeconds + 1) * time.Second) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - default: - } - - // Receive an RA with 0 lifetime. - // The prefix should get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0)) - expectPrefixEvent(subnet, false) -} - -// TestPrefixDiscoveryMaxRouters tests that only -// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. -func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - optSer := make(header.NDPOptionsSerializer, ipv6.MaxDiscoveredOnLinkPrefixes+2) - prefixes := [ipv6.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} - - // Receive an RA with 2 more than the max number of discovered on-link - // prefixes. - for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { - prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} - prefixAddr[7] = byte(i) - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address(prefixAddr[:]), - PrefixLen: 64, - } - prefixes[i] = prefix.Subnet() - buf := [30]byte{} - buf[0] = uint8(prefix.PrefixLen) - buf[1] = 128 - binary.BigEndian.PutUint32(buf[2:], 10) - copy(buf[14:], prefix.Address) - - optSer[i] = header.NDPPrefixInformation(buf[:]) - } - - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) - for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { - if i < ipv6.MaxDiscoveredOnLinkPrefixes { - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - } else { - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes") - default: - } - } - } -} - -// Checks to see if list contains an IPv6 address, item. -func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: item, - } - - return containsAddr(list, protocolAddress) -} - -// Check e to make sure that the event is for addr on nic with ID 1, and the -// event type is set to eventType. -func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { - return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) -} - -// TestAutoGenAddr tests that an address is properly generated and invalidated -// when configured to do so. -func TestAutoGenAddr(t *testing.T) { - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handleRAs, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) - } - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with 0 lifetime") - default: - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") - default: - } - - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } - - // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") - default: - } - - // Wait for addr of prefix1 to be invalidated. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } - }) -} - -func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string { - ret := "" - for _, c := range containList { - if !containsV6Addr(addrs, c) { - ret += fmt.Sprintf("should have %s in the list of addresses\n", c) - } - } - for _, c := range notContainList { - if containsV6Addr(addrs, c) { - ret += fmt.Sprintf("should not have %s in the list of addresses\n", c) - } - } - return ret -} - -// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when -// configured to do so as part of IPv6 Privacy Extensions. -func TestAutoGenTempAddr(t *testing.T) { - const ( - nicID = 1 - newMinVL = 5 - newMinVLDuration = newMinVL * time.Second - ) - - savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate - savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - ipv6.MaxDesyncFactor = time.Nanosecond - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - tests := []struct { - name string - dupAddrTransmits uint8 - retransmitTimer time.Duration - }{ - { - name: "DAD disabled", - }, - { - name: "DAD enabled", - dupAddrTransmits: 1, - retransmitTimer: time.Second, - }, - } - - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for i, test := range tests { - i := i - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - seed := []byte{uint8(i)} - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], seed, nicID) - newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { - return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) - } - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 2), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - TempIIDSeed: seed, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - - expectDADEventAsync := func(addr tcpip.Address) { - t.Helper() - - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) - default: - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - expectDADEventAsync(addr1.Address) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) - default: - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - tempAddr1 := newTempAddr(addr1.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr1, newAddr) - expectDADEventAsync(tempAddr1.Address) - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) - default: - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred - // lifetimes. - tempAddr2 := newTempAddr(addr2.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - expectDADEventAsync(addr2.Address) - expectAutoGenAddrEventAsync(tempAddr2, newAddr) - expectDADEventAsync(tempAddr2.Address) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Deprecate prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Refresh lifetimes for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Reduce valid lifetime and deprecate addresses of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Wait for addrs of prefix1 to be invalidated. They should be - // invalidated at the same time. - select { - case e := <-ndpDisp.autoGenAddrC: - var nextAddr tcpip.AddressWithPrefix - if e.addr == addr1 { - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = tempAddr1 - } else { - if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = addr1 - } - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) - } - - // Receive an RA with prefix2 in a PI w/ 0 lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unexpected auto gen addr event = %+v", e) - default: - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) - } - }) - } - }) -} - -// TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not -// generated for auto generated link-local addresses. -func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { - const nicID = 1 - - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = time.Nanosecond - - tests := []struct { - name string - dupAddrTransmits uint8 - retransmitTimer time.Duration - }{ - { - name: "DAD disabled", - }, - { - name: "DAD enabled", - dupAddrTransmits: 1, - retransmitTimer: time.Second, - }, - } - - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - AutoGenLinkLocal: true, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // The stable link-local address should auto-generate and resolve DAD. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - - // No new addresses should be generated. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unxpected auto gen addr event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - }) - } - }) -} - -// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address -// will not be generated until after DAD completes, even if a new Router -// Advertisement is received to refresh lifetimes. -func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { - const ( - nicID = 1 - dadTransmits = 1 - retransmitTimer = 2 * time.Second - ) - - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = 0 - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // Receive an RA to trigger SLAAC for prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - - // DAD on the stable address for prefix has not yet completed. Receiving a new - // RA that would refresh lifetimes should not generate a temporary SLAAC - // address for the prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - default: - } - - // Wait for DAD to complete for the stable address then expect the temporary - // address to be generated. - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } -} - -// TestAutoGenTempAddrRegen tests that temporary SLAAC addresses are -// regenerated. -func TestAutoGenTempAddrRegen(t *testing.T) { - const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) - - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(timeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Stop generating temporary addresses - ndpConfigs.AutoGenTempGlobalAddresses = false - if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else { - ndpEP := ipv6Ep.(ipv6.NDPEndpoint) - ndpEP.SetNDPConfigurations(ndpConfigs) - } - - // Wait for all the temporary addresses to get invalidated. - tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3} - invalidateAfter := newMinVLDuration - 2*regenAfter - for _, addr := range tempAddrs { - // Wait for a deprecation then invalidation event, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation jobs could execute in any - // order. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we shouldn't get a deprecation - // event after. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - } else { - t.Fatalf("got unexpected auto-generated event = %+v", e) - } - case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - - invalidateAfter = regenAfter - } - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" { - t.Fatal(mismatch) - } -} - -// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's -// regeneration job gets updated when refreshing the address's lifetimes. -func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { - const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) - - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(timeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Deprecate the prefix. - // - // A new temporary address should be generated after the regeneration - // time has passed since the prefix is deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0)) - expectAutoGenAddrEvent(addr, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): - } - - // Prefer the prefix again. - // - // A new temporary address should immediately be generated since the - // regeneration time has already passed since the last address was generated - // - this regeneration does not depend on a job. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr2, newAddr) - - // Increase the maximum lifetimes for temporary addresses to large values - // then refresh the lifetimes of the prefix. - // - // A new address should not be generated after the regeneration time that was - // expected for the previous check. This is because the preferred lifetime for - // the temporary addresses has increased, so it will take more time to - // regenerate a new temporary address. Note, new addresses are only - // regenerated after the preferred lifetime - the regenerate advance duration - // as paased. - ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second - ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second - ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } - ndpEP := ipv6Ep.(ipv6.NDPEndpoint) - ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): - } - - // Set the maximum lifetimes for temporary addresses such that on the next - // RA, the regeneration job gets scheduled again. - // - // The maximum lifetime is the sum of the minimum lifetimes for temporary - // addresses + the time that has already passed since the last address was - // generated so that the regeneration job is needed to generate the next - // address. - newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout - ndpConfigs.MaxTempAddrValidLifetime = newLifetimes - ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes - ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) -} - -// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response -// to a mix of DAD conflicts and NIC-local conflicts. -func TestMixedSLAACAddrConflictRegen(t *testing.T) { - const ( - nicID = 1 - nicName = "nic" - lifetimeSeconds = 9999 - // From stack.maxSLAACAddrLocalRegenAttempts - maxSLAACAddrLocalRegenAttempts = 10 - // We use 2 more addreses than the maximum local regeneration attempts - // because we want to also trigger regeneration in response to a DAD - // conflicts for this test. - maxAddrs = maxSLAACAddrLocalRegenAttempts + 2 - dupAddrTransmits = 1 - retransmitTimer = time.Second - ) - - var tempIIDHistoryWithModifiedEUI64 [header.IIDSize]byte - header.InitialTempIID(tempIIDHistoryWithModifiedEUI64[:], nil, nicID) - - var tempIIDHistoryWithOpaqueIID [header.IIDSize]byte - header.InitialTempIID(tempIIDHistoryWithOpaqueIID[:], nil, nicID) - - prefix, subnet, stableAddrWithModifiedEUI64 := prefixSubnetAddr(0, linkAddr1) - var stableAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix - var tempAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix - var tempAddrsWithModifiedEUI64 [maxAddrs]tcpip.AddressWithPrefix - addrBytes := []byte(subnet.ID()) - for i := 0; i < maxAddrs; i++ { - stableAddrsWithOpaqueIID[i] = tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, uint8(i), nil)), - PrefixLen: header.IIDOffsetInIPv6Address * 8, - } - // When generating temporary addresses, the resolved stable address for the - // SLAAC prefix will be the first address stable address generated for the - // prefix as we will not simulate address conflicts for the stable addresses - // in tests involving temporary addresses. Address conflicts for stable - // addresses will be done in their own tests. - tempAddrsWithOpaqueIID[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithOpaqueIID[:], stableAddrsWithOpaqueIID[0].Address) - tempAddrsWithModifiedEUI64[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithModifiedEUI64[:], stableAddrWithModifiedEUI64.Address) - } - - tests := []struct { - name string - addrs []tcpip.AddressWithPrefix - tempAddrs bool - initialExpect tcpip.AddressWithPrefix - maxAddrs int - nicNameFromID func(tcpip.NICID, string) string - }{ - { - name: "Stable addresses with opaque IIDs", - addrs: stableAddrsWithOpaqueIID[:], - maxAddrs: 1, - nicNameFromID: func(tcpip.NICID, string) string { - return nicName - }, - }, - { - name: "Temporary addresses with opaque IIDs", - addrs: tempAddrsWithOpaqueIID[:], - tempAddrs: true, - initialExpect: stableAddrsWithOpaqueIID[0], - maxAddrs: 1 /* initial (stable) address */ + maxSLAACAddrLocalRegenAttempts, - nicNameFromID: func(tcpip.NICID, string) string { - return nicName - }, - }, - { - name: "Temporary addresses with modified EUI64", - addrs: tempAddrsWithModifiedEUI64[:], - tempAddrs: true, - maxAddrs: 1 /* initial (stable) address */ + maxSLAACAddrLocalRegenAttempts, - initialExpect: stableAddrWithModifiedEUI64, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - // We may receive a deprecated and invalidated event for each SLAAC - // address that is assigned. - autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAddrs*2), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - Clock: clock, - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: test.tempAddrs, - AutoGenAddressConflictRetries: 1, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: test.nicNameFromID, - }, - })}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: llAddr2, - NIC: nicID, - }}) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - manuallyAssignedAddresses := make(map[tcpip.Address]struct{}) - for j := 0; j < len(test.addrs)-1; j++ { - // The NIC will not attempt to generate an address in response to a - // NIC-local conflict after some maximum number of attempts. We skip - // creating a conflict for the address that would be generated as part - // of the last attempt so we can simulate a DAD conflict for this - // address and restart the NIC-local generation process. - if j == maxSLAACAddrLocalRegenAttempts-1 { - continue - } - - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err) - } - - manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{} - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrAsyncEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - if diff := checkAutoGenAddrEvent(<-ndpDisp.autoGenAddrC, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - } - - expectDADEventAsync := func(addr tcpip.Address) { - t.Helper() - - clock.Advance(dupAddrTransmits * retransmitTimer) - if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - } - - // Enable DAD. - ndpDisp.dadC = make(chan ndpDADEvent, 2) - if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else { - ndpEP := ipv6Ep.(stack.DuplicateAddressDetector) - ndpEP.SetDADConfigurations(stack.DADConfigurations{ - DupAddrDetectTransmits: dupAddrTransmits, - RetransmitTimer: retransmitTimer, - }) - } - - // Do SLAAC for prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) - if test.initialExpect != (tcpip.AddressWithPrefix{}) { - expectAutoGenAddrEvent(test.initialExpect, newAddr) - expectDADEventAsync(test.initialExpect.Address) - } - - // The last local generation attempt should succeed, but we introduce a - // DAD failure to restart the local generation process. - addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1] - expectAutoGenAddrAsyncEvent(addr, newAddr) - rxNDPSolicit(e, addr.Address) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - expectAutoGenAddrEvent(addr, invalidatedAddr) - - // The last address generated should resolve DAD. - addr = test.addrs[len(test.addrs)-1] - expectAutoGenAddrAsyncEvent(addr, newAddr) - expectDADEventAsync(addr.Address) - - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - default: - } - - // Wait for all the SLAAC addresses to be invalidated. - clock.Advance(lifetimeSeconds * time.Second) - gotAddresses := make(map[tcpip.Address]struct{}) - for _, a := range s.NICInfo()[nicID].ProtocolAddresses { - gotAddresses[a.AddressWithPrefix.Address] = struct{}{} - } - if diff := cmp.Diff(manuallyAssignedAddresses, gotAddresses); diff != "" { - t.Fatalf("assigned addresses mismatch (-want +got):\n%s", diff) - } - }) - } -} - -// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher, -// channel.Endpoint and stack.Stack. -// -// stack.Stack will have a default route through the router (llAddr3) installed -// and a static link-address (linkAddr3) added to the link address cache for the -// router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { - t.Helper() - ndpDisp := &ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: ndpDisp, - })}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: llAddr3, - NIC: nicID, - }}) - - if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err) - } - return ndpDisp, e, s -} - -// addrForNewConnectionTo returns the local address used when creating a new -// connection to addr. -func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - ep.SocketOptions().SetV6Only(true) - if err := ep.Connect(addr); err != nil { - t.Fatalf("ep.Connect(%+v): %s", addr, err) - } - got, err := ep.GetLocalAddress() - if err != nil { - t.Fatalf("ep.GetLocalAddress(): %s", err) - } - return got.Addr -} - -// addrForNewConnection returns the local address used when creating a new -// connection. -func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { - t.Helper() - - return addrForNewConnectionTo(t, s, dstAddr) -} - -// addrForNewConnectionWithAddr returns the local address used when creating a -// new connection with a specific local address. -func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { - t.Helper() - - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - ep.SocketOptions().SetV6Only(true) - if err := ep.Bind(addr); err != nil { - t.Fatalf("ep.Bind(%+v): %s", addr, err) - } - if err := ep.Connect(dstAddr); err != nil { - t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) - } - got, err := ep.GetLocalAddress() - if err != nil { - t.Fatalf("ep.GetLocalAddress(): %s", err) - } - return got.Addr -} - -// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when -// receiving a PI with 0 preferred lifetime. -func TestAutoGenAddrDeprecateFromPI(t *testing.T) { - const nicID = 1 - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } - - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } - - // Receive PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - expectPrimaryAddr(addr1) - - // Deprecate addr for prefix1 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - // addr should still be the primary endpoint as there are no other addresses. - expectPrimaryAddr(addr1) - - // Refresh lifetimes of addr generated from prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) - - // Deprecate addr for prefix2 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr1 should be the primary endpoint now since addr2 is deprecated but - // addr1 is not. - expectPrimaryAddr(addr1) - // addr2 is deprecated but if explicitly requested, it should be used. - fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } - - // Another PI w/ 0 preferred lifetime should not result in a deprecation - // event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } - - // Refresh lifetimes of addr generated from prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) -} - -// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated -// when its preferred lifetime expires. -func TestAutoGenAddrJobDeprecation(t *testing.T) { - const nicID = 1 - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(timeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } - - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } - - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) - - // Receive a PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr1) - - // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since addr1 is deprecated but - // addr2 is not. - expectPrimaryAddr(addr2) - // addr1 is deprecated but if explicitly requested, it should be used. - fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } - - // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make - // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } - - // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - // addr1 is the primary endpoint again since it is non-deprecated now. - expectPrimaryAddr(addr1) - - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since it is not deprecated. - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } - - // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) - - // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - - // Wait for a deprecation then invalidation events, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation handlers could be handled in - // either deprecation then invalidation, or invalidation then deprecation - // (which should be cancelled by the invalidation handler). - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation - // event after. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultAsyncNegativeEventTimeout): - } - } else { - t.Fatalf("got unexpected auto-generated event") - } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should not have %s in the list of addresses", addr2) - } - // Should not have any primary endpoints. - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - ep.SocketOptions().SetV6Only(true) - - { - err := ep.Connect(dstAddr) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{}) - } - } -} - -// Tests transitioning a SLAAC address's valid lifetime between finite and -// infinite values. -func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { - const infiniteVLSeconds = 2 - const minVLSeconds = 1 - savedIL := header.NDPInfiniteLifetime - savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL - header.NDPInfiniteLifetime = savedIL - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second - header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - tests := []struct { - name string - infiniteVL uint32 - }{ - { - name: "EqualToInfiniteVL", - infiniteVL: infiniteVLSeconds, - }, - // Our implementation supports changing header.NDPInfiniteLifetime for tests - // such that a packet can be received where the lifetime field has a value - // greater than header.NDPInfiniteLifetime. Because of this, we test to make - // sure that receiving a value greater than header.NDPInfiniteLifetime is - // handled the same as when receiving a value equal to - // header.NDPInfiniteLifetime. - { - name: "MoreThanInfiniteVL", - infiniteVL: infiniteVLSeconds + 1, - }, - } - - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with finite prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - - default: - t.Fatal("expected addr auto gen event") - } - - // Receive an new RA with prefix with infinite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) - - // Receive a new RA with prefix with finite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - - case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") - } - }) - } - }) -} - -// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an -// auto-generated address only gets updated when required to, as specified in -// RFC 4862 section 5.5.3.e. -func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { - const infiniteVL = 4294967295 - const newMinVL = 4 - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second - - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - tests := []struct { - name string - ovl uint32 - nvl uint32 - evl uint32 - }{ - // Should update the VL to the minimum VL for updating if the - // new VL is less than newMinVL but was originally greater than - // it. - { - "LargeVLToVLLessThanMinVLForUpdate", - 9999, - 1, - newMinVL, - }, - { - "LargeVLTo0", - 9999, - 0, - newMinVL, - }, - { - "InfiniteVLToVLLessThanMinVLForUpdate", - infiniteVL, - 1, - newMinVL, - }, - { - "InfiniteVLTo0", - infiniteVL, - 0, - newMinVL, - }, - - // Should not update VL if original VL was less than newMinVL - // and the new VL is also less than newMinVL. - { - "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", - newMinVL - 1, - newMinVL - 3, - newMinVL - 1, - }, - - // Should take the new VL if the new VL is greater than the - // remaining time or is greater than newMinVL. - { - "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", - newMinVL + 5, - newMinVL + 3, - newMinVL + 3, - }, - { - "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", - newMinVL - 3, - newMinVL - 1, - newMinVL - 1, - }, - { - "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", - newMinVL - 3, - newMinVL + 1, - newMinVL + 1, - }, - } - - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix with initial VL, - // test.ovl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - - // Receive an new RA with prefix with new VL, - // test.nvl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) - - // - // Validate that the VL for the address got set - // to test.evl. - // - - // The address should not be invalidated until the effective valid - // lifetime has passed. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout): - } - - // Wait for the invalidation event. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") - } - }) - } - }) -} - -// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed -// by the user, its resources will be cleaned up and an invalidation event will -// be sent to the integrator. -func TestAutoGenAddrRemoval(t *testing.T) { - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive a PI to auto-generate an address. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr, newAddr) - - // Removing the address should result in an invalidation event - // immediately. - if err := s.RemoveAddress(1, addr.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err) - } - expectAutoGenAddrEvent(addr, invalidatedAddr) - - // Wait for the original valid lifetime to make sure the original job got - // cancelled/cleaned up. - clock.Advance(lifetimeSeconds * time.Second) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - default: - } -} - -// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously -// assigned to the NIC but is in the permanentExpired state. -func TestAutoGenAddrAfterRemoval(t *testing.T) { - const nicID = 1 - - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } - - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } - - // Receive a PI to auto-generate addr1 with a large valid and preferred - // lifetime. - const largeLifetimeSeconds = 999 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr1, newAddr) - expectPrimaryAddr(addr1) - - // Add addr2 as a static address. - protoAddr2 := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr2, - } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) - } - // addr2 should be more preferred now since it is at the front of the primary - // list. - expectPrimaryAddr(addr2) - - // Get a route using addr2 to increment its reference count then remove it - // to leave it in the permanentExpired state. - r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) - } - defer r.Release() - if err := s.RemoveAddress(nicID, addr2.Address); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) - } - // addr1 should be preferred again since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and preferred. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr2 should be more preferred now that it is closer to the front of the - // primary list and not deprecated. - expectPrimaryAddr(addr2) - - // Removing the address should result in an invalidation event immediately. - // It should still be in the permanentExpired state because r is still held. - // - // We remove addr2 here to make sure addr2 was marked as a SLAAC address - // (it was previously marked as a static address). - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - // addr1 should be more preferred since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr1 should still be more preferred since addr2 is deprecated, even though - // it is closer to the front of the primary list. - expectPrimaryAddr(addr1) - - // Receive a PI to refresh addr2's preferred lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto gen addr event") - default: - } - // addr2 should be more preferred now that it is not deprecated. - expectPrimaryAddr(addr2) - - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - expectPrimaryAddr(addr1) -} - -// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that -// is already assigned to the NIC, the static address remains. -func TestAutoGenAddrStaticConflict(t *testing.T) { - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Add the address as a static address before SLAAC tries to add it. - if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Receive a PI where the generated address will be the same as the one - // that we already have assigned statically. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") - default: - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - - // Should not get an invalidation event after the PI's invalidation - // time. - clock.Advance(lifetimeSeconds * time.Second) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - default: - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } -} - -func makeSecretKey(t *testing.T) []byte { - secretKey := make([]byte, header.OpaqueIIDSecretKeyMinBytes) - n, err := cryptorand.Read(secretKey) - if err != nil { - t.Fatalf("cryptorand.Read(_): %s", err) - } - if l := len(secretKey); n != l { - t.Fatalf("got cryptorand.Read(_) = (%d, nil), want = (%d, nil)", n, l) - } - return secretKey -} - -// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use -// opaque interface identifiers when configured to do so. -func TestAutoGenAddrWithOpaqueIID(t *testing.T) { - const nicID = 1 - const nicName = "nic1" - - secretKey := makeSecretKey(t) - - prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) - prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) - // addr1 and addr2 are the addresses that are expected to be generated when - // stack.Stack is configured to generate opaque interface identifiers as - // defined by RFC 7217. - addrBytes := []byte(subnet1.ID()) - addr1 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)), - PrefixLen: 64, - } - addrBytes = []byte(subnet2.ID()) - addr2 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)), - PrefixLen: 64, - } - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: secretKey, - }, - })}, - Clock: clock, - }) - opts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive an RA with prefix1 in a PI. - const validLifetimeSecondPrefix1 = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - - // Receive an RA with prefix2 in a PI with a large valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - - // Wait for addr of prefix1 to be invalidated. - clock.Advance(validLifetimeSecondPrefix1 * time.Second) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } -} - -func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { - const nicID = 1 - const nicName = "nic" - const dadTransmits = 1 - const retransmitTimer = time.Second - const maxMaxRetries = 3 - const lifetimeSeconds = 10 - - // Needed for the temporary address sub test. - savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MaxDesyncFactor = time.Nanosecond - - secretKey := makeSecretKey(t) - - prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) - - addrForSubnet := func(subnet tcpip.Subnet, dadCounter uint8) tcpip.AddressWithPrefix { - addrBytes := []byte(subnet.ID()) - return tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, dadCounter, secretKey)), - PrefixLen: 64, - } - } - - expectAutoGenAddrEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - - expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { - t.Helper() - - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, res); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - } - - expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { - t.Helper() - - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, res); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - } - - stableAddrForTempAddrTest := addrForSubnet(subnet, 0) - - addrTypes := []struct { - name string - ndpConfigs ipv6.NDPConfigurations - autoGenLinkLocal bool - prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix - addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix - }{ - { - name: "Global address", - ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { - // Receive an RA with prefix1 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) - return nil - - }, - addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { - return addrForSubnet(subnet, dadCounter) - }, - }, - { - name: "LinkLocal address", - ndpConfigs: ipv6.NDPConfigurations{}, - autoGenLinkLocal: true, - prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { - return nil - }, - addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { - return addrForSubnet(header.IPv6LinkLocalPrefix.Subnet(), dadCounter) - }, - }, - { - name: "Temporary address", - ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { - header.InitialTempIID(tempIIDHistory, nil, nicID) - - // Generate a stable SLAAC address so temporary addresses will be - // generated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) - expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) - - // The stable address will be assigned throughout the test. - return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} - }, - addrGenFn: func(_ uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix { - return header.GenerateTempIPv6SLAACAddr(tempIIDHistory, stableAddrForTempAddrTest.Address) - }, - }, - } - - for _, addrType := range addrTypes { - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the parallel - // tests complete and limit the number of parallel tests running at the same - // time to reduce flakes. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run(addrType.name, func(t *testing.T) { - for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { - for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { - maxRetries := maxRetries - numFailures := numFailures - addrType := addrType - - t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - ndpConfigs := addrType.ndpConfigs - ndpConfigs.AutoGenAddressConflictRetries = maxRetries - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: addrType.autoGenLinkLocal, - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: secretKey, - }, - })}, - }) - opts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) - } - - var tempIIDHistory [header.IIDSize]byte - stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:]) - - // Simulate DAD conflicts so the address is regenerated. - for i := uint8(0); i < numFailures; i++ { - addr := addrType.addrGenFn(i, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) - - // Should not have any new addresses assigned to the NIC. - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // Simulate a DAD conflict. - rxNDPSolicit(e, addr.Address) - expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) - expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) - - // Attempting to add the address manually should not fail if the - // address's state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) - } - if err := s.RemoveAddress(nicID, addr.Address); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) - } - expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{}) - } - - // Should not have any new addresses assigned to the NIC. - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { - t.Fatal(mismatch) - } - - // If we had less failures than generation attempts, we should have - // an address after DAD resolves. - if maxRetries+1 > numFailures { - addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) - expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{}) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { - t.Fatal(mismatch) - } - } - - // Should not attempt address generation again. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - }) - } - } - }) - } -} - -// TestAutoGenAddrWithEUI64IIDNoDADRetries tests that a regeneration attempt is -// not made for SLAAC addresses generated with an IID based on the NIC's link -// address. -func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { - const nicID = 1 - const dadTransmits = 1 - const retransmitTimer = time.Second - const maxRetries = 3 - const lifetimeSeconds = 10 - - prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) - - addrTypes := []struct { - name string - ndpConfigs ipv6.NDPConfigurations - autoGenLinkLocal bool - subnet tcpip.Subnet - triggerSLAACFn func(e *channel.Endpoint) - }{ - { - name: "Global address", - ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenAddressConflictRetries: maxRetries, - }, - subnet: subnet, - triggerSLAACFn: func(e *channel.Endpoint) { - // Receive an RA with prefix1 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) - - }, - }, - { - name: "LinkLocal address", - ndpConfigs: ipv6.NDPConfigurations{ - AutoGenAddressConflictRetries: maxRetries, - }, - autoGenLinkLocal: true, - subnet: header.IPv6LinkLocalPrefix.Subnet(), - triggerSLAACFn: func(e *channel.Endpoint) {}, - }, - } - - for _, addrType := range addrTypes { - addrType := addrType - - t.Run(addrType.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - addrType.triggerSLAACFn(e) - - addrBytes := []byte(addrType.subnet.ID()) - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr1, addrBytes[header.IIDOffsetInIPv6Address:]) - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: 64, - } - expectAutoGenAddrEvent(addr, newAddr) - - // Simulate a DAD conflict. - rxNDPSolicit(e, addr.Address) - expectAutoGenAddrEvent(addr, invalidatedAddr) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - - // Should not attempt address regeneration. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - default: - } - }) - } -} - -// TestAutoGenAddrContinuesLifetimesAfterRetry tests that retrying address -// generation in response to DAD conflicts does not refresh the lifetimes. -func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { - const nicID = 1 - const nicName = "nic" - const dadTransmits = 1 - const retransmitTimer = 2 * time.Second - const failureTimer = time.Second - const maxRetries = 1 - const lifetimeSeconds = 5 - - secretKey := makeSecretKey(t) - - prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenAddressConflictRetries: maxRetries, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: secretKey, - }, - })}, - Clock: clock, - }) - opts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - // Receive an RA with prefix in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) - - addrBytes := []byte(subnet.ID()) - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 0, secretKey)), - PrefixLen: 64, - } - expectAutoGenAddrEvent(addr, newAddr) - - // Simulate a DAD conflict after some time has passed. - clock.Advance(failureTimer) - rxNDPSolicit(e, addr.Address) - expectAutoGenAddrEvent(addr, invalidatedAddr) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DAD event") - } - - // Let the next address resolve. - addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey)) - expectAutoGenAddrEvent(addr, newAddr) - clock.Advance(dadTransmits * retransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for DAD event") - } - - // Address should be deprecated/invalidated after the lifetime expires. - // - // Note, the remaining lifetime is calculated from when the PI was first - // processed. Since we wait for some time before simulating a DAD conflict - // and more time for the new address to resolve, the new address is only - // expected to be valid for the remaining time. The DAD conflict should - // not have reset the lifetimes. - // - // We expect either just the invalidation event or the deprecation event - // followed by the invalidation event. - clock.Advance(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer) - select { - case e := <-ndpDisp.autoGenAddrC: - if e.eventType == deprecatedAddr { - if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") - } - } else { - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - } - default: - t.Fatal("timed out waiting for auto gen addr event") - } -} - -// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event -// to the integrator when an RA is received with the NDP Recursive DNS Server -// option with at least one valid address. -func TestNDPRecursiveDNSServerDispatch(t *testing.T) { - tests := []struct { - name string - opt header.NDPRecursiveDNSServer - expected *ndpRDNSS - }{ - { - "Unspecified", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }), - nil, - }, - { - "Multicast", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, - }), - nil, - }, - { - "OptionTooSmall", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, - }), - nil, - }, - { - "0Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - }), - nil, - }, - { - "Valid1Address", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - }, - 2 * time.Second, - }, - }, - { - "Valid2Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", - }, - time.Second, - }, - }, - { - "Valid3Addresses", - header.NDPRecursiveDNSServer([]byte{ - 0, 0, - 0, 0, 0, 0, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, - 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3, - }), - &ndpRDNSS{ - []tcpip.Address{ - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", - "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03", - }, - 0, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - // We do not expect more than a single RDNSS - // event at any time for this test. - rdnssC: make(chan ndpRDNSSEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - }, - NDPDisp: &ndpDisp, - })}, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt})) - - if test.expected != nil { - select { - case e := <-ndpDisp.rdnssC: - if e.nicID != 1 { - t.Errorf("got rdnss nicID = %d, want = 1", e.nicID) - } - if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" { - t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff) - } - if e.rdnss.lifetime != test.expected.lifetime { - t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime) - } - default: - t.Fatal("expected an RDNSS option event") - } - } - - // Should have no more RDNSS options. - select { - case e := <-ndpDisp.rdnssC: - t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e) - default: - } - }) - } -} - -// TestNDPDNSSearchListDispatch tests that the integrator is informed when an -// NDP DNS Search List option is received with at least one domain name in the -// search list. -func TestNDPDNSSearchListDispatch(t *testing.T) { - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dnsslC: make(chan ndpDNSSLEvent, 3), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - }, - NDPDisp: &ndpDisp, - })}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - optSer := header.NDPOptionsSerializer{ - header.NDPDNSSearchList([]byte{ - 0, 0, - 0, 0, 0, 0, - 2, 'h', 'i', - 0, - }), - header.NDPDNSSearchList([]byte{ - 0, 0, - 0, 0, 0, 1, - 1, 'i', - 0, - 2, 'a', 'm', - 2, 'm', 'e', - 0, - }), - header.NDPDNSSearchList([]byte{ - 0, 0, - 0, 0, 1, 0, - 3, 'x', 'y', 'z', - 0, - 5, 'h', 'e', 'l', 'l', 'o', - 5, 'w', 'o', 'r', 'l', 'd', - 0, - 4, 't', 'h', 'i', 's', - 2, 'i', 's', - 1, 'a', - 4, 't', 'e', 's', 't', - 0, - }), - } - expected := []struct { - domainNames []string - lifetime time.Duration - }{ - { - domainNames: []string{ - "hi", - }, - lifetime: 0, - }, - { - domainNames: []string{ - "i", - "am.me", - }, - lifetime: time.Second, - }, - { - domainNames: []string{ - "xyz", - "hello.world", - "this.is.a.test", - }, - lifetime: 256 * time.Second, - }, - } - - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) - - for i, expected := range expected { - select { - case dnssl := <-ndpDisp.dnsslC: - if dnssl.nicID != nicID { - t.Errorf("got %d-th dnssl nicID = %d, want = %d", i, dnssl.nicID, nicID) - } - if diff := cmp.Diff(dnssl.domainNames, expected.domainNames); diff != "" { - t.Errorf("%d-th dnssl domain names mismatch (-want +got):\n%s", i, diff) - } - if dnssl.lifetime != expected.lifetime { - t.Errorf("got %d-th dnssl lifetime = %s, want = %s", i, dnssl.lifetime, expected.lifetime) - } - default: - t.Fatal("expected a DNSSL event") - } - } - - // Should have no more DNSSL options. - select { - case <-ndpDisp.dnsslC: - t.Fatal("unexpectedly got a DNSSL event") - default: - } -} - -func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { - const ( - lifetimeSeconds = 999 - nicID = 1 - ) - - ndpDisp := ndpDispatcher{ - offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), - prefixC: make(chan ndpPrefixEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: true, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - e1 := channel.New(0, header.IPv6MinimumMTU, linkAddr1) - if err := s.CreateNIC(nicID, e1); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - llAddr := tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen} - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, llAddr, newAddr); diff != "" { - t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) - } - default: - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", llAddr, nicID) - } - - prefix, subnet, addr := prefixSubnetAddr(0, linkAddr1) - e1.InjectInbound( - header.IPv6ProtocolNumber, - raBufWithPI( - llAddr3, - lifetimeSeconds, - prefix, - true, /* onLink */ - true, /* auto */ - lifetimeSeconds, - lifetimeSeconds, - ), - ) - select { - case e := <-ndpDisp.offLinkRouteC: - if diff := checkOffLinkRouteEvent(e, nicID, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" { - t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) - } - default: - t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID) - } - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { - t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) - } - default: - t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) - } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) - } - default: - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", addr, nicID) - } - - // Enabling or disabling forwarding should not invalidate discovered prefixes - // or routers, or auto-generated address. - for _, forwarding := range [...]bool{true, false} { - t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) { - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) - } - select { - case e := <-ndpDisp.offLinkRouteC: - t.Errorf("unexpected off-link route event = %#v", e) - default: - } - select { - case e := <-ndpDisp.prefixC: - t.Errorf("unexpected prefix event = %#v", e) - default: - } - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("unexpected auto-gen addr event = %#v", e) - default: - } - }) - } -} - -func TestCleanupNDPState(t *testing.T) { - const ( - lifetimeSeconds = 5 - maxRouterAndPrefixEvents = 4 - nicID1 = 1 - nicID2 = 2 - ) - - prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1) - e2Addr1 := addrForSubnet(subnet1, linkAddr2) - e2Addr2 := addrForSubnet(subnet2, linkAddr2) - llAddrWithPrefix1 := tcpip.AddressWithPrefix{ - Address: llAddr1, - PrefixLen: 64, - } - llAddrWithPrefix2 := tcpip.AddressWithPrefix{ - Address: llAddr2, - PrefixLen: 64, - } - - tests := []struct { - name string - cleanupFn func(t *testing.T, s *stack.Stack) - keepAutoGenLinkLocal bool - maxAutoGenAddrEvents int - skipFinalAddrCheck bool - }{ - // A NIC should cleanup all NDP state when it is disabled. - { - name: "Disable NIC", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.DisableNIC(nicID1); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) - } - if err := s.DisableNIC(nicID2); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) - } - }, - keepAutoGenLinkLocal: false, - maxAutoGenAddrEvents: 6, - }, - - // A NIC should cleanup all NDP state when it is removed. - { - name: "Remove NIC", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.RemoveNIC(nicID1); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err) - } - if err := s.RemoveNIC(nicID2); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err) - } - }, - keepAutoGenLinkLocal: false, - maxAutoGenAddrEvents: 6, - // The NICs are removed so we can't check their addresses after calling - // stopFn. - skipFinalAddrCheck: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - offLinkRouteC: make(chan ndpOffLinkRouteEvent, maxRouterAndPrefixEvents), - prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), - } - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: true, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - expectOffLinkRouteEvent := func() (bool, ndpOffLinkRouteEvent) { - select { - case e := <-ndpDisp.offLinkRouteC: - return true, e - default: - } - - return false, ndpOffLinkRouteEvent{} - } - - expectPrefixEvent := func() (bool, ndpPrefixEvent) { - select { - case e := <-ndpDisp.prefixC: - return true, e - default: - } - - return false, ndpPrefixEvent{} - } - - expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { - select { - case e := <-ndpDisp.autoGenAddrC: - return true, e - default: - } - - return false, ndpAutoGenAddrEvent{} - } - - e1 := channel.New(0, 1280, linkAddr1) - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - // We have other tests that make sure we receive the *correct* events - // on normal discovery of routers/prefixes, and auto-generated - // addresses. Here we just make sure we get an event and let other tests - // handle the correctness check. - expectAutoGenAddrEvent() - - e2 := channel.New(0, 1280, linkAddr2) - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - expectAutoGenAddrEvent() - - // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and - // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from - // llAddr4) to discover multiple routers and prefixes, and auto-gen - // multiple addresses. - - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectOffLinkRouteEvent(); !ok { - t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) - } - - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectOffLinkRouteEvent(); !ok { - t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) - } - - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectOffLinkRouteEvent(); !ok { - t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) - } - - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectOffLinkRouteEvent(); !ok { - t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) - } - - // We should have the auto-generated addresses added. - nicinfo := s.NICInfo() - nic1Addrs := nicinfo[nicID1].ProtocolAddresses - nic2Addrs := nicinfo[nicID2].ProtocolAddresses - if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } - - // We can't proceed any further if we already failed the test (missing - // some discovery/auto-generated address events or addresses). - if t.Failed() { - t.FailNow() - } - - test.cleanupFn(t, s) - - // Collect invalidation events after having NDP state cleaned up. - gotOffLinkRouteEvents := make(map[ndpOffLinkRouteEvent]int) - for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectOffLinkRouteEvent() - if !ok { - t.Errorf("expected %d off-link route events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) - break - } - gotOffLinkRouteEvents[e]++ - } - gotPrefixEvents := make(map[ndpPrefixEvent]int) - for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectPrefixEvent() - if !ok { - t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) - break - } - gotPrefixEvents[e]++ - } - gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) - for i := 0; i < test.maxAutoGenAddrEvents; i++ { - ok, e := expectAutoGenAddrEvent() - if !ok { - t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i) - break - } - gotAutoGenAddrEvents[e]++ - } - - // No need to proceed any further if we already failed the test (missing - // some invalidation events). - if t.Failed() { - t.FailNow() - } - - expectedOffLinkRouteEvents := map[ndpOffLinkRouteEvent]int{ - {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, - {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, - {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, - {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, - } - if diff := cmp.Diff(expectedOffLinkRouteEvents, gotOffLinkRouteEvents); diff != "" { - t.Errorf("off-link route events mismatch (-want +got):\n%s", diff) - } - expectedPrefixEvents := map[ndpPrefixEvent]int{ - {nicID: nicID1, prefix: subnet1, discovered: false}: 1, - {nicID: nicID1, prefix: subnet2, discovered: false}: 1, - {nicID: nicID2, prefix: subnet1, discovered: false}: 1, - {nicID: nicID2, prefix: subnet2, discovered: false}: 1, - } - if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { - t.Errorf("prefix events mismatch (-want +got):\n%s", diff) - } - expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ - {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, - } - - if !test.keepAutoGenLinkLocal { - expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1 - expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1 - } - - if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { - t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) - } - - if !test.skipFinalAddrCheck { - // Make sure the auto-generated addresses got removed. - nicinfo = s.NICInfo() - nic1Addrs = nicinfo[nicID1].ProtocolAddresses - nic2Addrs = nicinfo[nicID2].ProtocolAddresses - if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - } - if containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - } - if containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } - } - - // Should not get any more events (invalidation timers should have been - // cancelled when the NDP state was cleaned up). - clock.Advance(lifetimeSeconds * time.Second) - select { - case <-ndpDisp.offLinkRouteC: - t.Error("unexpected off-link route event") - default: - } - select { - case <-ndpDisp.prefixC: - t.Error("unexpected prefix event") - default: - } - select { - case <-ndpDisp.autoGenAddrC: - t.Error("unexpected auto-generated address event") - default: - } - }) - } -} - -// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly -// informed when new information about what configurations are available via -// DHCPv6 is learned. -func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectDHCPv6Event := func(configuration ipv6.DHCPv6ConfigurationFromNDPRA) { - t.Helper() - select { - case e := <-ndpDisp.dhcpv6ConfigurationC: - if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" { - t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected DHCPv6 configuration event") - } - } - - expectNoDHCPv6Event := func() { - t.Helper() - select { - case <-ndpDisp.dhcpv6ConfigurationC: - t.Fatal("unexpected DHCPv6 configuration event") - default: - } - } - - // Even if the first RA reports no DHCPv6 configurations are available, the - // dispatcher should get an event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) - // Receiving the same update again should not result in an event to the - // dispatcher. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Other - // Configurations. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Managed Address. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to none. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Managed Address. - // - // Note, when the M flag is set, the O flag is redundant. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectNoDHCPv6Event() - // Even though the DHCPv6 flags are different, the effective configuration is - // the same so we should not receive a new event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectNoDHCPv6Event() - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectNoDHCPv6Event() - - // Receive an RA that updates the DHCPv6 configuration to Other - // Configurations. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectNoDHCPv6Event() - - // Cycling the NIC should cause the last DHCPv6 configuration to be cleared. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - // Receive an RA that updates the DHCPv6 configuration to Other - // Configurations. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectNoDHCPv6Event() -} - -var _ rand.Source = (*savingRandSource)(nil) - -type savingRandSource struct { - s rand.Source - - lastInt63 int64 -} - -func (d *savingRandSource) Int63() int64 { - i := d.s.Int63() - d.lastInt63 = i - return i -} -func (d *savingRandSource) Seed(seed int64) { - d.s.Seed(seed) -} - -// TestRouterSolicitation tests the initial Router Solicitations that are sent -// when a NIC newly becomes enabled. -func TestRouterSolicitation(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - linkHeaderLen uint16 - linkAddr tcpip.LinkAddress - nicAddr tcpip.Address - expectedSrcAddr tcpip.Address - expectedNDPOpts []header.NDPOption - maxRtrSolicit uint8 - rtrSolicitInt time.Duration - effectiveRtrSolicitInt time.Duration - maxRtrSolicitDelay time.Duration - effectiveMaxRtrSolicitDelay time.Duration - }{ - { - name: "Single RS with 2s delay and interval", - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 1, - rtrSolicitInt: 2 * time.Second, - effectiveRtrSolicitInt: 2 * time.Second, - maxRtrSolicitDelay: 2 * time.Second, - effectiveMaxRtrSolicitDelay: 2 * time.Second, - }, - { - name: "Single RS with 4s delay and interval", - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 1, - rtrSolicitInt: 4 * time.Second, - effectiveRtrSolicitInt: 4 * time.Second, - maxRtrSolicitDelay: 4 * time.Second, - effectiveMaxRtrSolicitDelay: 4 * time.Second, - }, - { - name: "Two RS with delay", - linkHeaderLen: 1, - nicAddr: llAddr1, - expectedSrcAddr: llAddr1, - maxRtrSolicit: 2, - rtrSolicitInt: 2 * time.Second, - effectiveRtrSolicitInt: 2 * time.Second, - maxRtrSolicitDelay: 500 * time.Millisecond, - effectiveMaxRtrSolicitDelay: 500 * time.Millisecond, - }, - { - name: "Single RS without delay", - linkHeaderLen: 2, - linkAddr: linkAddr1, - nicAddr: llAddr1, - expectedSrcAddr: llAddr1, - expectedNDPOpts: []header.NDPOption{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }, - maxRtrSolicit: 1, - rtrSolicitInt: 2 * time.Second, - effectiveRtrSolicitInt: 2 * time.Second, - maxRtrSolicitDelay: 0, - effectiveMaxRtrSolicitDelay: 0, - }, - { - name: "Two RS without delay and invalid zero interval", - linkHeaderLen: 3, - linkAddr: linkAddr1, - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 2, - rtrSolicitInt: 0, - effectiveRtrSolicitInt: 4 * time.Second, - maxRtrSolicitDelay: 0, - effectiveMaxRtrSolicitDelay: 0, - }, - { - name: "Three RS without delay", - linkAddr: linkAddr1, - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 3, - rtrSolicitInt: 500 * time.Millisecond, - effectiveRtrSolicitInt: 500 * time.Millisecond, - maxRtrSolicitDelay: 0, - effectiveMaxRtrSolicitDelay: 0, - }, - { - name: "Two RS with invalid negative delay", - linkAddr: linkAddr1, - expectedSrcAddr: header.IPv6Any, - maxRtrSolicit: 2, - rtrSolicitInt: time.Second, - effectiveRtrSolicitInt: time.Second, - maxRtrSolicitDelay: -3 * time.Second, - effectiveMaxRtrSolicitDelay: time.Second, - }, - } - - subTests := []struct { - name string - handleRAs ipv6.HandleRAsConfiguration - afterFirstRS func(*testing.T, *stack.Stack) - }{ - { - name: "Handle RAs when forwarding disabled", - handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - afterFirstRS: func(*testing.T, *stack.Stack) {}, - }, - - // Enabling forwarding when RAs are always configured to be handled - // should not stop router solicitations. - { - name: "Handle RAs always", - handleRAs: ipv6.HandlingRAsAlwaysEnabled, - afterFirstRS: func(t *testing.T, s *stack.Stack) { - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) - } - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - clock := faketime.NewManualClock() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() - - clock.Advance(timeout) - p, ok := e.Read() - if !ok { - t.Fatal("expected router solicitation packet") - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - - // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) - - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() - - clock.Advance(timeout) - if p, ok := e.Read(); ok { - t.Fatalf("unexpectedly got a packet = %#v", p) - } - } - randSource := savingRandSource{ - s: rand.NewSource(time.Now().UnixNano()), - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: subTest.handleRAs, - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - Clock: clock, - RandSource: &randSource, - }) - - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } - } - - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining != 0 { - maxRtrSolicitDelay := test.maxRtrSolicitDelay - if maxRtrSolicitDelay < 0 { - maxRtrSolicitDelay = ipv6.DefaultNDPConfigurations().MaxRtrSolicitationDelay - } - var actualRtrSolicitDelay time.Duration - if maxRtrSolicitDelay != 0 { - actualRtrSolicitDelay = time.Duration(randSource.lastInt63) % maxRtrSolicitDelay - } - waitForPkt(actualRtrSolicitDelay) - remaining-- - } - - subTest.afterFirstRS(t, s) - - for ; remaining != 0; remaining-- { - if test.effectiveRtrSolicitInt != 0 { - waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) - waitForPkt(time.Nanosecond) - } else { - waitForPkt(0) - } - } - - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt) - } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay) - } - - if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) - } - }) - } - }) - } -} - -func TestStopStartSolicitingRouters(t *testing.T) { - const nicID = 1 - const delay = 0 - const interval = 500 * time.Millisecond - const maxRtrSolicitations = 3 - - tests := []struct { - name string - startFn func(t *testing.T, s *stack.Stack) - // first is used to tell stopFn that it is being called for the first time - // after router solicitations were last enabled. - stopFn func(t *testing.T, s *stack.Stack, first bool) - }{ - // Tests that when forwarding is enabled or disabled, router solicitations - // are stopped or started, respectively. - { - name: "Enable and disable forwarding", - startFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, false); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", ipv6.ProtocolNumber, err) - } - }, - stopFn: func(t *testing.T, s *stack.Stack, _ bool) { - t.Helper() - - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) - } - }, - }, - - // Tests that when a NIC is enabled or disabled, router solicitations - // are started or stopped, respectively. - { - name: "Enable and disable NIC", - startFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - }, - stopFn: func(t *testing.T, s *stack.Stack, _ bool) { - t.Helper() - - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - }, - }, - - // Tests that when a NIC is removed, router solicitations are stopped. We - // cannot start router solications on a removed NIC. - { - name: "Remove NIC", - stopFn: func(t *testing.T, s *stack.Stack, first bool) { - t.Helper() - - // Only try to remove the NIC the first time stopFn is called since it's - // impossible to remove an already removed NIC. - if !first { - return - } - - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) - } - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(maxRtrSolicitations, 1280, linkAddr1) - waitForPkt := func(clock *faketime.ManualClock, timeout time.Duration) { - t.Helper() - - clock.Advance(timeout) - p, ok := e.Read() - if !ok { - t.Fatal("timed out waiting for packet") - } - - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS()) - } - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: interval, - MaxRtrSolicitationDelay: delay, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // Stop soliciting routers. - test.stopFn(t, s, true /* first */) - clock.Advance(delay) - if _, ok := e.Read(); ok { - // A single RS may have been sent before solicitations were stopped. - clock.Advance(interval) - if _, ok = e.Read(); ok { - t.Fatal("should not have sent more than one RS message") - } - } - - // Stopping router solicitations after it has already been stopped should - // do nothing. - test.stopFn(t, s, false /* first */) - clock.Advance(delay) - if _, ok := e.Read(); ok { - t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") - } - - // If test.startFn is nil, there is no way to restart router solications. - if test.startFn == nil { - return - } - - // Start soliciting routers. - test.startFn(t, s) - waitForPkt(clock, delay) - waitForPkt(clock, interval) - waitForPkt(clock, interval) - clock.Advance(interval) - if _, ok := e.Read(); ok { - t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") - } - - // Starting router solicitations after it has already completed should do - // nothing. - test.startFn(t, s) - clock.Advance(interval) - if _, ok := e.Read(); ok { - t.Fatal("unexpectedly got a packet after finishing router solicitations") - } - }) - } -} diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go deleted file mode 100644 index 7de25fe37..000000000 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ /dev/null @@ -1,1584 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "math" - "math/rand" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" -) - -const ( - // entryStoreSize is the default number of entries that will be generated and - // added to the entry store. This number needs to be larger than the size of - // the neighbor cache to give ample opportunity for verifying behavior during - // cache overflows. Four times the size of the neighbor cache allows for - // three complete cache overflows. - entryStoreSize = 4 * neighborCacheSize - - // typicalLatency is the typical latency for an ARP or NDP packet to travel - // to a router and back. - typicalLatency = time.Millisecond - - // testEntryBroadcastAddr is a special address that indicates a packet should - // be sent to all nodes. - testEntryBroadcastAddr = tcpip.Address("broadcast") - - // testEntryBroadcastLinkAddr is a special link address sent back to - // multicast neighbor probes. - testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") - - // infiniteDuration indicates that a task will not occur in our lifetime. - infiniteDuration = time.Duration(math.MaxInt64) -) - -// unorderedEventsDiffOpts returns options passed to cmp.Diff to sort slices of -// events for cases where ordering must be ignored. -func unorderedEventsDiffOpts() []cmp.Option { - return []cmp.Option{ - cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { - return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0 - }), - } -} - -// unorderedEntriesDiffOpts returns options passed to cmp.Diff to sort slices of -// entries for cases where ordering must be ignored. -func unorderedEntriesDiffOpts() []cmp.Option { - return []cmp.Option{ - cmpopts.SortSlices(func(a, b NeighborEntry) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - }), - } -} - -func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *testNeighborResolver { - config.resetInvalidFields() - rng := rand.New(rand.NewSource(time.Now().UnixNano())) - linkRes := &testNeighborResolver{ - clock: clock, - entries: newTestEntryStore(), - delay: typicalLatency, - } - linkRes.neigh.init(&nic{ - stack: &Stack{ - clock: clock, - nudDisp: nudDisp, - nudConfigs: config, - randomGenerator: rng, - }, - id: 1, - stats: makeNICStats(tcpip.NICStats{}.FillIn()), - }, linkRes) - return linkRes -} - -// testEntryStore contains a set of IP to NeighborEntry mappings. -type testEntryStore struct { - mu sync.RWMutex - entriesMap map[tcpip.Address]NeighborEntry -} - -func toAddress(i uint16) tcpip.Address { - return tcpip.Address([]byte{ - 1, - 0, - byte(i >> 8), - byte(i), - }) -} - -func toLinkAddress(i uint16) tcpip.LinkAddress { - return tcpip.LinkAddress([]byte{ - 1, - 0, - 0, - 0, - byte(i >> 8), - byte(i), - }) -} - -// newTestEntryStore returns a testEntryStore pre-populated with entries. -func newTestEntryStore() *testEntryStore { - store := &testEntryStore{ - entriesMap: make(map[tcpip.Address]NeighborEntry), - } - for i := uint16(0); i < entryStoreSize; i++ { - addr := toAddress(i) - linkAddr := toLinkAddress(i) - - store.entriesMap[addr] = NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - } - } - return store -} - -// size returns the number of entries in the store. -func (s *testEntryStore) size() uint16 { - s.mu.RLock() - defer s.mu.RUnlock() - return uint16(len(s.entriesMap)) -} - -// entry returns the entry at index i. Returns an empty entry and false if i is -// out of bounds. -func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) { - return s.entryByAddr(toAddress(i)) -} - -// entryByAddr returns the entry matching addr for situations when the index is -// not available. Returns an empty entry and false if no entries match addr. -func (s *testEntryStore) entryByAddr(addr tcpip.Address) (NeighborEntry, bool) { - s.mu.RLock() - defer s.mu.RUnlock() - entry, ok := s.entriesMap[addr] - return entry, ok -} - -// entries returns all entries in the store. -func (s *testEntryStore) entries() []NeighborEntry { - entries := make([]NeighborEntry, 0, len(s.entriesMap)) - s.mu.RLock() - defer s.mu.RUnlock() - for i := uint16(0); i < entryStoreSize; i++ { - addr := toAddress(i) - if entry, ok := s.entriesMap[addr]; ok { - entries = append(entries, entry) - } - } - return entries -} - -// set modifies the link addresses of an entry. -func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) { - addr := toAddress(i) - s.mu.Lock() - defer s.mu.Unlock() - if entry, ok := s.entriesMap[addr]; ok { - entry.LinkAddr = linkAddr - s.entriesMap[addr] = entry - } -} - -// testNeighborResolver implements LinkAddressResolver to emulate sending a -// neighbor probe. -type testNeighborResolver struct { - clock tcpip.Clock - neigh neighborCache - entries *testEntryStore - delay time.Duration - onLinkAddressRequest func() - dropReplies bool -} - -var _ LinkAddressResolver = (*testNeighborResolver)(nil) - -func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { - if !r.dropReplies { - // Delay handling the request to emulate network latency. - r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(targetAddr) - }) - } - - // Execute post address resolution action, if available. - if f := r.onLinkAddressRequest; f != nil { - f() - } - return nil -} - -// fakeRequest emulates handling a response for a link address request. -func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) { - if entry, ok := r.entries.entryByAddr(addr); ok { - r.neigh.handleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - } -} - -func (*testNeighborResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == testEntryBroadcastAddr { - return testEntryBroadcastLinkAddr, true - } - return "", false -} - -func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return 0 -} - -func TestNeighborCacheGetConfig(t *testing.T) { - nudDisp := testNUDDispatcher{} - c := DefaultNUDConfigurations() - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, c, clock) - - if got, want := linkRes.neigh.config(), c; got != want { - t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestNeighborCacheSetConfig(t *testing.T) { - nudDisp := testNUDDispatcher{} - c := DefaultNUDConfigurations() - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, c, clock) - - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - linkRes.neigh.setConfig(c) - - if got, want := linkRes.neigh.config(), c; got != want { - t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.ManualClock, linkRes *testNeighborResolver, entry NeighborEntry, removed []NeighborEntry) error { - var gotLinkResolutionResult LinkResolutionResult - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - gotLinkResolutionResult = r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - return fmt.Errorf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - - { - var wantEvents []testEntryEventInfo - - for _, removedEntry := range removed { - wantEvents = append(wantEvents, testEntryEventInfo{ - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }) - } - - wantEvents = append(wantEvents, testEntryEventInfo{ - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAt: clock.Now(), - }, - }) - - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - clock.Advance(typicalLatency) - - select { - case <-ch: - default: - return fmt.Errorf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) - } - wantLinkResolutionResult := LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil} - if diff := cmp.Diff(wantLinkResolutionResult, gotLinkResolutionResult); diff != "" { - return fmt.Errorf("got link resolution result mismatch (-want +got):\n%s", diff) - } - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - return nil -} - -func addReachableEntry(nudDisp *testNUDDispatcher, clock *faketime.ManualClock, linkRes *testNeighborResolver, entry NeighborEntry) error { - return addReachableEntryWithRemoved(nudDisp, clock, linkRes, entry, nil /* removed */) -} - -func TestNeighborCacheEntry(t *testing.T) { - c := DefaultNUDConfigurations() - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, c, clock) - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - if _, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err != nil { - t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) - } - - // No more events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestNeighborCacheRemoveEntry(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - linkRes.neigh.removeEntry(entry.Addr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - { - _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - } -} - -type testContext struct { - clock *faketime.ManualClock - linkRes *testNeighborResolver - nudDisp *testNUDDispatcher -} - -func newTestContext(c NUDConfigurations) testContext { - nudDisp := &testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(nudDisp, c, clock) - - return testContext{ - clock: clock, - linkRes: linkRes, - nudDisp: nudDisp, - } -} - -type overflowOptions struct { - startAtEntryIndex uint16 - wantStaticEntries []NeighborEntry -} - -func (c *testContext) overflowCache(opts overflowOptions) error { - // Fill the neighbor cache to capacity to verify the LRU eviction strategy is - // working properly after the entry removal. - for i := opts.startAtEntryIndex; i < c.linkRes.entries.size(); i++ { - var removedEntries []NeighborEntry - - // When beyond the full capacity, the cache will evict an entry as per the - // LRU eviction strategy. Note that the number of static entries should not - // affect the total number of dynamic entries that can be added. - if i >= neighborCacheSize+opts.startAtEntryIndex { - removedEntry, ok := c.linkRes.entries.entry(i - neighborCacheSize) - if !ok { - return fmt.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i-neighborCacheSize) - } - removedEntries = append(removedEntries, removedEntry) - } - - entry, ok := c.linkRes.entries.entry(i) - if !ok { - return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i) - } - if err := addReachableEntryWithRemoved(c.nudDisp, c.clock, c.linkRes, entry, removedEntries); err != nil { - return fmt.Errorf("addReachableEntryWithRemoved(...) = %s", err) - } - } - - // Expect to find only the most recent entries. The order of entries reported - // by entries() is nondeterministic, so entries have to be sorted before - // comparison. - wantUnorderedEntries := opts.wantStaticEntries - for i := c.linkRes.entries.size() - neighborCacheSize; i < c.linkRes.entries.size(); i++ { - entry, ok := c.linkRes.entries.entry(i) - if !ok { - return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i) - } - durationReachableNanos := time.Duration(c.linkRes.entries.size()-i-1) * typicalLatency - wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: c.clock.Now().Add(-durationReachableNanos), - } - wantUnorderedEntries = append(wantUnorderedEntries, wantEntry) - } - - if diff := cmp.Diff(wantUnorderedEntries, c.linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" { - return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) - } - - // No more events should have been dispatched. - c.nudDisp.mu.Lock() - defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.mu.events); diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - return nil -} - -// TestNeighborCacheOverflow verifies that the LRU cache eviction strategy -// respects the dynamic entry count. -func TestNeighborCacheOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - opts := overflowOptions{ - startAtEntryIndex: 0, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -// TestNeighborCacheRemoveEntryThenOverflow verifies that the LRU cache -// eviction strategy respects the dynamic entry count when an entry is removed. -func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - // Add a dynamic entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - // Remove the entry - c.linkRes.neigh.removeEntry(entry.Addr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 0, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -// TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress verifies that -// adding a duplicate static entry with the same link address does not dispatch -// any events. -func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { - config := DefaultNUDConfigurations() - c := newTestContext(config) - - // Add a static entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - staticLinkAddr := entry.LinkAddr + "static" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - // Add a duplicate static entry with the same link address. - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - - c.nudDisp.mu.Lock() - defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -// TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress verifies that -// adding a duplicate static entry with a different link address dispatches a -// change event. -func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { - config := DefaultNUDConfigurations() - c := newTestContext(config) - - // Add a static entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - staticLinkAddr := entry.LinkAddr + "static" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - // Add a duplicate entry with a different link address - staticLinkAddr += "duplicate" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } -} - -// TestNeighborCacheRemoveStaticEntryThenOverflow verifies that the LRU cache -// eviction strategy respects the dynamic entry count when a static entry is -// added then removed. In this case, the dynamic entry count shouldn't have -// been touched. -func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - // Add a static entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - staticLinkAddr := entry.LinkAddr + "static" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - // Remove the static entry that was just added - c.linkRes.neigh.removeEntry(entry.Addr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 0, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -// TestNeighborCacheOverwriteWithStaticEntryThenOverflow verifies that the LRU -// cache eviction strategy keeps count of the dynamic entry count when an entry -// is overwritten by a static entry. Static entries should not count towards -// the size of the LRU cache. -func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - // Add a dynamic entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - // Override the entry with a static one using the same address - staticLinkAddr := entry.LinkAddr + "static" - c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: c.clock.Now(), - }, - }, - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 1, - wantStaticEntries: []NeighborEntry{ - { - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - c.linkRes.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) - e, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - } - if diff := cmp.Diff(want, e); diff != "" { - t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 1, - wantStaticEntries: []NeighborEntry{ - { - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAt: c.clock.Now(), - }, - }, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -func TestNeighborCacheClear(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - // Add a dynamic entry. - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - // Add a static entry. - linkRes.neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - // Clear should remove both dynamic and static entries. - linkRes.neigh.clear() - - // Remove events dispatched from clear() have no deterministic order so they - // need to be sorted before comparison. - wantUnorderedEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(wantUnorderedEvents, nudDisp.mu.events, unorderedEventsDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -// TestNeighborCacheClearThenOverflow verifies that the LRU cache eviction -// strategy keeps count of the dynamic entry count when all entries are -// cleared. -func TestNeighborCacheClearThenOverflow(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - c := newTestContext(config) - - // Add a dynamic entry - entry, ok := c.linkRes.entries.entry(0) - if !ok { - t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - // Clear the cache. - c.linkRes.neigh.clear() - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: c.clock.Now(), - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - opts := overflowOptions{ - startAtEntryIndex: 0, - } - if err := c.overflowCache(opts); err != nil { - t.Errorf("c.overflowCache(%+v): %s", opts, err) - } -} - -func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { - config := DefaultNUDConfigurations() - // Stay in Reachable so the cache can overflow - config.BaseReachableTime = infiniteDuration - config.MinRandomFactor = 1 - config.MaxRandomFactor = 1 - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - startedAt := clock.Now() - - // The following logic is very similar to overflowCache, but - // periodically refreshes the frequently used entry. - - // Fill the neighbor cache to capacity - for i := uint16(0); i < neighborCacheSize; i++ { - entry, ok := linkRes.entries.entry(i) - if !ok { - t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) - } - if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - } - - frequentlyUsedEntry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - - // Keep adding more entries - for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ { - // Periodically refresh the frequently used entry - if i%(neighborCacheSize/2) == 0 { - if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { - t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", frequentlyUsedEntry.Addr, err) - } - } - - entry, ok := linkRes.entries.entry(i) - if !ok { - t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) - } - - // An entry should have been removed, as per the LRU eviction strategy - removedEntry, ok := linkRes.entries.entry(i - neighborCacheSize + 1) - if !ok { - t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i-neighborCacheSize+1) - } - - if err := addReachableEntryWithRemoved(&nudDisp, clock, linkRes, entry, []NeighborEntry{removedEntry}); err != nil { - t.Fatalf("addReachableEntryWithRemoved(...) = %s", err) - } - } - - // Expect to find only the frequently used entry and the most recent entries. - // The order of entries reported by entries() is nondeterministic, so entries - // have to be sorted before comparison. - wantUnsortedEntries := []NeighborEntry{ - { - Addr: frequentlyUsedEntry.Addr, - LinkAddr: frequentlyUsedEntry.LinkAddr, - State: Reachable, - // Can be inferred since the frequently used entry is the first to - // be created and transitioned to Reachable. - UpdatedAt: startedAt.Add(typicalLatency), - }, - } - - for i := linkRes.entries.size() - neighborCacheSize + 1; i < linkRes.entries.size(); i++ { - entry, ok := linkRes.entries.entry(i) - if !ok { - t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) - } - durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency - wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now().Add(-durationReachableNanos), - }) - } - - if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" { - t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) - } - - // No more events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestNeighborCacheConcurrent(t *testing.T) { - const concurrentProcesses = 16 - - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - storeEntries := linkRes.entries.entries() - for _, entry := range storeEntries { - var wg sync.WaitGroup - for r := 0; r < concurrentProcesses; r++ { - wg.Add(1) - go func(entry NeighborEntry) { - defer wg.Done() - switch e, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err.(type) { - case nil, *tcpip.ErrWouldBlock: - default: - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) - } - }(entry) - } - - // Wait for all goroutines to send a request - wg.Wait() - - // Process all the requests for a single entry concurrently - clock.Advance(typicalLatency) - } - - // All goroutines add in the same order and add more values than can fit in - // the cache. Our eviction strategy requires that the last entries are - // present, up to the size of the neighbor cache, and the rest are missing. - // The order of entries reported by entries() is nondeterministic, so entries - // have to be sorted before comparison. - var wantUnsortedEntries []NeighborEntry - for i := linkRes.entries.size() - neighborCacheSize; i < linkRes.entries.size(); i++ { - entry, ok := linkRes.entries.entry(i) - if !ok { - t.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i) - } - durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency - wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now().Add(-durationReachableNanos), - }) - } - - if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" { - t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) - } -} - -func TestNeighborCacheReplace(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - // Notify of a link address change - var updatedLinkAddr tcpip.LinkAddress - { - entry, ok := linkRes.entries.entry(1) - if !ok { - t.Fatal("got linkRes.entries.entry(1) = _, false, want = true") - } - updatedLinkAddr = entry.LinkAddr - } - linkRes.entries.set(0, updatedLinkAddr) - linkRes.neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - - // Requesting the entry again should start neighbor reachability confirmation. - // - // Verify the entry's new link address and the new state. - { - e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Delay, - UpdatedAt: clock.Now(), - } - if diff := cmp.Diff(want, e); diff != "" { - t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } - } - - clock.Advance(config.DelayFirstProbeTime + typicalLatency) - - // Verify that the neighbor is now reachable. - { - e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - } - if diff := cmp.Diff(want, e); diff != "" { - t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } - } -} - -func TestNeighborCacheResolutionFailed(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - - var requestCount uint32 - linkRes.onLinkAddressRequest = func() { - atomic.AddUint32(&requestCount, 1) - } - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - - // First, sanity check that resolution is working - if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { - t.Fatalf("addReachableEntry(...) = %s", err) - } - - got, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } - - // Verify address resolution fails for an unknown address. - before := atomic.LoadUint32(&requestCount) - - entry.Addr += "2" - { - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{Err: &tcpip.ErrTimeout{}}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) - } - } - - maxAttempts := linkRes.neigh.config().MaxUnicastProbes - if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want { - t.Errorf("got link address request count = %d, want = %d", got, want) - } -} - -// TestNeighborCacheResolutionTimeout simulates sending MaxMulticastProbes -// probes and not retrieving a confirmation before the duration defined by -// MaxMulticastProbes * RetransmitTimer. -func TestNeighborCacheResolutionTimeout(t *testing.T) { - config := DefaultNUDConfigurations() - config.RetransmitTimer = time.Millisecond // small enough to cause timeout - - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(nil, config, clock) - // large enough to cause timeout - linkRes.delay = time.Minute - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{Err: &tcpip.ErrTimeout{}}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) - } -} - -// TestNeighborCacheRetryResolution simulates retrying communication after -// failing to perform address resolution. -func TestNeighborCacheRetryResolution(t *testing.T) { - config := DefaultNUDConfigurations() - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, config, clock) - // Simulate a faulty link. - linkRes.dropReplies = true - - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") - } - - // Perform address resolution with a faulty link, which will fail. - { - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{Err: &tcpip.ErrTimeout{}}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) - } - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - { - wantEntries := []NeighborEntry{ - { - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAt: clock.Now(), - }, - } - if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, unorderedEntriesDiffOpts()...); diff != "" { - t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) - } - } - } - - // Retry address resolution with a working link. - linkRes.dropReplies = false - { - incompleteEntry, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - if incompleteEntry.State != Incomplete { - t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) - } - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - clock.Advance(typicalLatency) - - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) - } - - { - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - { - gotEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Fatalf("linkRes.neigh.entry(%s, '', _): %s", entry.Addr, err) - } - - wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - } - if diff := cmp.Diff(gotEntry, wantEntry); diff != "" { - t.Fatalf("neighbor entry mismatch (-got, +want):\n%s", diff) - } - } - } -} - -func BenchmarkCacheClear(b *testing.B) { - b.StopTimer() - config := DefaultNUDConfigurations() - clock := tcpip.NewStdClock() - linkRes := newTestNeighborResolver(nil, config, clock) - linkRes.delay = 0 - - // Clear for every possible size of the cache - for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ { - // Fill the neighbor cache to capacity. - for i := uint16(0); i < cacheSize; i++ { - entry, ok := linkRes.entries.entry(i) - if !ok { - b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) - } - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" { - b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - b.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - - select { - case <-ch: - default: - b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) - } - } - - b.StartTimer() - linkRes.neigh.clear() - b.StopTimer() - } -} diff --git a/pkg/tcpip/stack/neighbor_entry_list.go b/pkg/tcpip/stack/neighbor_entry_list.go new file mode 100644 index 000000000..d78430080 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_entry_list.go @@ -0,0 +1,221 @@ +package stack + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type neighborEntryElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (neighborEntryElementMapper) linkerFor(elem *neighborEntry) *neighborEntry { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type neighborEntryList struct { + head *neighborEntry + tail *neighborEntry +} + +// Reset resets list l to the empty state. +func (l *neighborEntryList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *neighborEntryList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *neighborEntryList) Front() *neighborEntry { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *neighborEntryList) Back() *neighborEntry { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *neighborEntryList) Len() (count int) { + for e := l.Front(); e != nil; e = (neighborEntryElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *neighborEntryList) PushFront(e *neighborEntry) { + linker := neighborEntryElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + neighborEntryElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *neighborEntryList) PushBack(e *neighborEntry) { + linker := neighborEntryElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + neighborEntryElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *neighborEntryList) PushBackList(m *neighborEntryList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + neighborEntryElementMapper{}.linkerFor(l.tail).SetNext(m.head) + neighborEntryElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *neighborEntryList) InsertAfter(b, e *neighborEntry) { + bLinker := neighborEntryElementMapper{}.linkerFor(b) + eLinker := neighborEntryElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + neighborEntryElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *neighborEntryList) InsertBefore(a, e *neighborEntry) { + aLinker := neighborEntryElementMapper{}.linkerFor(a) + eLinker := neighborEntryElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + neighborEntryElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *neighborEntryList) Remove(e *neighborEntry) { + linker := neighborEntryElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + neighborEntryElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + neighborEntryElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type neighborEntryEntry struct { + next *neighborEntry + prev *neighborEntry +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *neighborEntryEntry) Next() *neighborEntry { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *neighborEntryEntry) Prev() *neighborEntry { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *neighborEntryEntry) SetNext(elem *neighborEntry) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *neighborEntryEntry) SetPrev(elem *neighborEntry) { + e.prev = elem +} diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go deleted file mode 100644 index 59d86d6d4..000000000 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ /dev/null @@ -1,2269 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "math" - "math/rand" - "sync" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -const ( - entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - - entryTestNICID tcpip.NICID = 1 - - entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") - entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") -) - -var ( - entryTestAddr1 = testutil.MustParse6("a::1") - entryTestAddr2 = testutil.MustParse6("a::2") -) - -// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current -// time. -func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { - clock.Advance(immediateDuration) -} - -// The following unit tests exercise every state transition and verify its -// behavior with RFC 4681 and RFC 7048. -// -// | From | To | Cause | Update | Action | Event | -// | =========== | =========== | ========================================== | ======== | ===========| ======= | -// | Unknown | Unknown | Confirmation w/ unknown address | | | Added | -// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added | -// | Unknown | Stale | Probe | | | Added | -// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed | -// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed | -// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed | -// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed | -// | Incomplete | Unreachable | Max probes sent without reply | | Notify | Changed | -// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | | -// | Reachable | Stale | Reachable timer expired | | | Changed | -// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed | -// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Stale | Stale | Override confirmation | LinkAddr | | Changed | -// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed | -// | Stale | Delay | Packet sent | | | Changed | -// | Delay | Reachable | Upper-layer confirmation | | | Changed | -// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Delay | Stale | Probe or confirmation w/ different address | | | Changed | -// | Delay | Probe | Delay timer expired | | Send probe | Changed | -// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed | -// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Probe | Stale | Probe or confirmation w/ different address | | | Changed | -// | Probe | Probe | Retransmit timer expired | | | Changed | -// | Probe | Unreachable | Max probes sent without reply | | Notify | Changed | -// | Unreachable | Incomplete | Packet queued | | Send probe | Changed | -// | Unreachable | Stale | Probe w/ different address | LinkAddr | | Changed | - -type testEntryEventType uint8 - -const ( - entryTestAdded testEntryEventType = iota - entryTestChanged - entryTestRemoved -) - -func (t testEntryEventType) String() string { - switch t { - case entryTestAdded: - return "add" - case entryTestChanged: - return "change" - case entryTestRemoved: - return "remove" - default: - return fmt.Sprintf("unknown (%d)", t) - } -} - -// Fields are exported for use with cmp.Diff. -type testEntryEventInfo struct { - EventType testEntryEventType - NICID tcpip.NICID - Entry NeighborEntry -} - -func (e testEntryEventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry) -} - -// testNUDDispatcher implements NUDDispatcher to validate the dispatching of -// events upon certain NUD state machine events. -type testNUDDispatcher struct { - mu struct { - sync.Mutex - events []testEntryEventInfo - } -} - -var _ NUDDispatcher = (*testNUDDispatcher)(nil) - -func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { - d.mu.Lock() - defer d.mu.Unlock() - d.mu.events = append(d.mu.events, e) -} - -func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) { - d.queueEvent(testEntryEventInfo{ - EventType: entryTestAdded, - NICID: nicID, - Entry: entry, - }) -} - -func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) { - d.queueEvent(testEntryEventInfo{ - EventType: entryTestChanged, - NICID: nicID, - Entry: entry, - }) -} - -func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) { - d.queueEvent(testEntryEventInfo{ - EventType: entryTestRemoved, - NICID: nicID, - Entry: entry, - }) -} - -type entryTestLinkResolver struct { - mu struct { - sync.Mutex - probes []entryTestProbeInfo - } -} - -var _ LinkAddressResolver = (*entryTestLinkResolver)(nil) - -type entryTestProbeInfo struct { - RemoteAddress tcpip.Address - RemoteLinkAddress tcpip.LinkAddress - LocalAddress tcpip.Address -} - -func (p entryTestProbeInfo) String() string { - return fmt.Sprintf("probe with RemoteAddress=%q, RemoteLinkAddress=%q, LocalAddress=%q", p.RemoteAddress, p.RemoteLinkAddress, p.LocalAddress) -} - -// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts -// to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { - r.mu.Lock() - defer r.mu.Unlock() - r.mu.probes = append(r.mu.probes, entryTestProbeInfo{ - RemoteAddress: targetAddr, - RemoteLinkAddress: linkAddr, - LocalAddress: localAddr, - }) - return nil -} - -// ResolveStaticAddress attempts to resolve address without sending requests. -// It either resolves the name immediately or returns the empty LinkAddress. -func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) { - return "", false -} - -// LinkAddressProtocol returns the network protocol of the addresses this -// resolver can resolve. -func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return entryTestNetNumber -} - -func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) { - clock := faketime.NewManualClock() - disp := testNUDDispatcher{} - nic := nic{ - LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint - - id: entryTestNICID, - stack: &Stack{ - clock: clock, - nudDisp: &disp, - nudConfigs: c, - randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), - }, - stats: makeNICStats(tcpip.NICStats{}.FillIn()), - } - netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) - nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ - header.IPv6ProtocolNumber: netEP, - } - - var linkRes entryTestLinkResolver - // Stub out the neighbor cache to verify deletion from the cache. - l := &linkResolver{ - resolver: &linkRes, - } - l.neigh.init(&nic, &linkRes) - - entry := newNeighborEntry(&l.neigh, entryTestAddr1 /* remoteAddr */, l.neigh.state) - l.neigh.mu.Lock() - l.neigh.mu.cache[entryTestAddr1] = entry - l.neigh.mu.Unlock() - nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]*linkResolver{ - header.IPv6ProtocolNumber: l, - } - - return entry, &disp, &linkRes, clock -} - -// TestEntryInitiallyUnknown verifies that the state of a newly created -// neighborEntry is Unknown. -func TestEntryInitiallyUnknown(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - if e.mu.neigh.State != Unknown { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) - } - e.mu.Unlock() - - clock.Advance(c.RetransmitTimer) - - // No probes should have been sent. - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Unknown { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) - } - e.mu.Unlock() - - clock.Advance(time.Hour) - - // No probes should have been sent. - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryUnknownToIncomplete(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } -} - -func unknownToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - if err := func() error { - e.mu.Lock() - defer e.mu.Unlock() - - if e.mu.neigh.State != Unknown { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) - } - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - return nil - }(); err != nil { - return err - } - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.mu.probes) - linkRes.mu.probes = nil - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAt: clock.Now(), - }, - }, - } - { - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - return nil -} - -func TestEntryUnknownToStale(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } -} - -func unknownToStale(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - if err := func() error { - e.mu.Lock() - defer e.mu.Unlock() - - if e.mu.neigh.State != Unknown { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) - } - e.handleProbeLocked(entryTestLinkAddr1) - if e.mu.neigh.State != Stale { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - return nil - }(); err != nil { - return err - } - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - { - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - return nil -} - -func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - - // UpdatedAt should remain the same during address resolution. - e.mu.Lock() - startedAt := e.mu.neigh.UpdatedAt - e.mu.Unlock() - - // Wait for the rest of the reachability probe transmissions, signifying - // Incomplete to Incomplete transitions. - for i := uint32(1); i < c.MaxMulticastProbes; i++ { - clock.Advance(c.RetransmitTimer) - - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.mu.probes) - linkRes.mu.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - e.mu.Lock() - if got, want := e.mu.neigh.UpdatedAt, startedAt; got != want { - t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want) - } - e.mu.Unlock() - } - - // UpdatedAt should change after failing address resolution. Timing out after - // sending the last probe transitions the entry to Unreachable. - clock.Advance(c.RetransmitTimer) - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryIncompleteToReachable(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } -} - -func incompleteToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock, flags ReachabilityConfirmationFlags) error { - if err := func() error { - e.mu.Lock() - defer e.mu.Unlock() - - if e.mu.neigh.State != Incomplete { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - e.handleConfirmationLocked(entryTestLinkAddr1, flags) - if e.mu.neigh.State != Reachable { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - return fmt.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - return nil - }(); err != nil { - return err - } - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - { - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - return nil -} - -func incompleteToReachable(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - if err := incompleteToReachableWithFlags(e, nudDisp, linkRes, clock, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }); err != nil { - return err - } - - e.mu.Lock() - isRouter := e.mu.isRouter - e.mu.Unlock() - if isRouter { - return fmt.Errorf("got e.mu.isRouter = %t, want = false", isRouter) - } - - return nil -} - -func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachableWithRouterFlag(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachableWithRouterFlag(...) = %s", err) - } -} - -func incompleteToReachableWithRouterFlag(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - if err := incompleteToReachableWithFlags(e, nudDisp, linkRes, clock, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: true, - }); err != nil { - return err - } - - e.mu.Lock() - isRouter := e.mu.isRouter - e.mu.Unlock() - if !isRouter { - return fmt.Errorf("got e.mu.isRouter = %t, want = true", isRouter) - } - - return nil -} - -func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - if e.mu.isRouter { - t.Errorf("got e.mu.isRouter = %t, want = false", e.mu.isRouter) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr1) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryIncompleteToUnreachable(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToUnreachable(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToUnreachable(...) = %s", err) - } -} - -func incompleteToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Incomplete { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Incomplete) - } - } - - // The first probe was sent in the transition from Unknown to Incomplete. - clock.Advance(c.RetransmitTimer) - - // Observe each subsequent multicast probe transmitted. - for i := uint32(1); i < c.MaxMulticastProbes; i++ { - wantProbes := []entryTestProbeInfo{{ - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: "", - LocalAddress: entryTestAddr2, - }} - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.mu.probes) - linkRes.mu.probes = nil - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probe #%d mismatch (-want, +got):\n%s", i+1, diff) - } - - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Incomplete { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Incomplete) - } - - clock.Advance(c.RetransmitTimer) - } - - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Unreachable { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Unreachable) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - return nil -} - -type testLocker struct{} - -var _ sync.Locker = (*testLocker)(nil) - -func (*testLocker) Lock() {} -func (*testLocker) Unlock() {} - -func TestEntryReachableToReachableClearsRouterWhenConfirmationWithoutRouter(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachableWithRouterFlag(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachableWithRouterFlag(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if got, want := e.mu.isRouter, false; got != want { - t.Errorf("got e.mu.isRouter = %t, want = %t", got, want) - } - ipv6EP := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) - if ipv6EP.invalidatedRtr != e.mu.neigh.Addr { - t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.mu.neigh.Addr) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events) - nudDisp.mu.Unlock() - if diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestEntryReachableToReachableWhenProbeWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr1) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events) - nudDisp.mu.Unlock() - if diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestEntryReachableToStaleWhenTimeout(t *testing.T) { - c := DefaultNUDConfigurations() - // Eliminate random factors from ReachableTime computation so the transition - // from Stale to Reachable will only take BaseReachableTime duration. - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } - if err := reachableToStale(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("reachableToStale(...) = %s", err) - } -} - -// reachableToStale transitions a neighborEntry in Reachable state to Stale -// state. Depends on the elimination of random factors in the ReachableTime -// computation. -// -// c.MinRandomFactor = 1 -// c.MaxRandomFactor = 1 -func reachableToStale(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - // Ensure there are no random factors in the ReachableTime computation. - if c.MinRandomFactor != 1 { - return fmt.Errorf("got c.MinRandomFactor = %f, want = 1", c.MinRandomFactor) - } - if c.MaxRandomFactor != 1 { - return fmt.Errorf("got c.MaxRandomFactor = %f, want = 1", c.MaxRandomFactor) - } - - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Reachable { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Reachable) - } - } - - clock.Advance(c.BaseReachableTime) - - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Stale { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Stale) - } - } - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - { - - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - return nil -} - -func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr2) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToStaleWhenProbeWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr1) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) - } - e.mu.Unlock() - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr2) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryStaleToDelay(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } -} - -func staleToDelay(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - if err := func() error { - e.mu.Lock() - defer e.mu.Unlock() - - if e.mu.neigh.State != Stale { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Delay { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) - } - return nil - }(); err != nil { - return err - } - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - return nil -} - -func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - - e.mu.Lock() - e.handleUpperLevelConfirmationLocked() - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if e.mu.neigh.State != Reachable { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Delay { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) - } - if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr2) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryDelayToProbe(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } -} - -func delayToProbe(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Delay { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Delay) - } - } - - // Wait for the first unicast probe to be transmitted, marking the - // transition from Delay to Probe. - clock.Advance(c.DelayFirstProbeTime) - - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Probe { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Probe) - } - } - - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - { - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.mu.probes) - linkRes.mu.probes = nil - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - UpdatedAt: clock.Now(), - }, - }, - } - { - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - return nil -} - -func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - - e.mu.Lock() - e.handleProbeLocked(entryTestLinkAddr2) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Stale { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - nudDisp.mu.Unlock() -} - -func TestEntryProbeToProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if e.mu.neigh.State != Probe { - t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want) - } - e.mu.Unlock() - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events) - nudDisp.mu.Unlock() - if diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } -} - -func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - if err := probeToReachableWithOverride(e, nudDisp, linkRes, clock, entryTestLinkAddr2); err != nil { - t.Fatalf("probeToReachableWithOverride(...) = %s", err) - } -} - -func probeToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) error { - if err := func() error { - e.mu.Lock() - defer e.mu.Unlock() - - prevLinkAddr := e.mu.neigh.LinkAddr - if e.mu.neigh.State != Probe { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) - } - e.handleConfirmationLocked(linkAddr, flags) - - if e.mu.neigh.State != Reachable { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) - } - if linkAddr == "" { - linkAddr = prevLinkAddr - } - if e.mu.neigh.LinkAddr != linkAddr { - return fmt.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, linkAddr) - } - return nil - }(); err != nil { - return err - } - - // No probes should have been sent. - runImmediatelyScheduledJobs(clock) - { - linkRes.mu.Lock() - diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes) - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: linkAddr, - State: Reachable, - UpdatedAt: clock.Now(), - }, - }, - } - { - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - - return nil -} - -func probeToReachableWithOverride(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock, linkAddr tcpip.LinkAddress) error { - return probeToReachableWithFlags(e, nudDisp, linkRes, clock, linkAddr, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) -} - -func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - if err := probeToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("probeToReachable(...) = %s", err) - } -} - -func probeToReachable(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - return probeToReachableWithFlags(e, nudDisp, linkRes, clock, entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) -} - -func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { - c := DefaultNUDConfigurations() - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - if err := probeToReachableWithoutAddress(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("probeToReachableWithoutAddress(...) = %s", err) - } -} - -func probeToReachableWithoutAddress(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - return probeToReachableWithFlags(e, nudDisp, linkRes, clock, "" /* linkAddr */, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) -} - -func TestEntryProbeToUnreachable(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - c.MaxUnicastProbes = 3 - c.DelayFirstProbeTime = c.RetransmitTimer - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToStale(...) = %s", err) - } - if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("staleToDelay(...) = %s", err) - } - if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("delayToProbe(...) = %s", err) - } - if err := probeToUnreachable(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("probeToUnreachable(...) = %s", err) - } -} - -func probeToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Probe { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Probe) - } - } - - // The first probe was sent in the transition from Delay to Probe. - clock.Advance(c.RetransmitTimer) - - // Observe each subsequent unicast probe transmitted. - for i := uint32(1); i < c.MaxUnicastProbes; i++ { - wantProbes := []entryTestProbeInfo{{ - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }} - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.mu.probes) - linkRes.mu.probes = nil - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probe #%d mismatch (-want, +got):\n%s", i+1, diff) - } - - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Probe { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Probe) - } - - clock.Advance(c.RetransmitTimer) - } - - { - e.mu.Lock() - state := e.mu.neigh.State - e.mu.Unlock() - if state != Unreachable { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Unreachable) - } - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Unreachable, - UpdatedAt: clock.Now(), - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - - return nil -} - -func TestEntryUnreachableToIncomplete(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToUnreachable(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToUnreachable(...) = %s", err) - } - if err := unreachableToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unreachableToIncomplete(...) = %s", err) - } -} - -func unreachableToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error { - if err := func() error { - e.mu.Lock() - defer e.mu.Unlock() - - if e.mu.neigh.State != Unreachable { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) - } - e.handlePacketQueuedLocked(entryTestAddr2) - if e.mu.neigh.State != Incomplete { - return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) - } - return nil - }(); err != nil { - return err - } - - runImmediatelyScheduledJobs(clock) - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(wantProbes, linkRes.mu.probes) - linkRes.mu.probes = nil - linkRes.mu.Unlock() - if diff != "" { - return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAt: clock.Now(), - }, - }, - } - { - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) - } - } - return nil -} - -func TestEntryUnreachableToStale(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - // Eliminate random factors from ReachableTime computation so the transition - // from Stale to Reachable will only take BaseReachableTime duration. - c.MinRandomFactor = 1 - c.MaxRandomFactor = 1 - - e, nudDisp, linkRes, clock := entryTestSetup(c) - - if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unknownToIncomplete(...) = %s", err) - } - if err := incompleteToUnreachable(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToUnreachable(...) = %s", err) - } - if err := unreachableToIncomplete(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("unreachableToIncomplete(...) = %s", err) - } - if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("incompleteToReachable(...) = %s", err) - } - if err := reachableToStale(c, e, nudDisp, linkRes, clock); err != nil { - t.Fatalf("reachableToStale(...) = %s", err) - } -} diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go deleted file mode 100644 index 5cb342f78..000000000 --- a/pkg/tcpip/stack/nic_test.go +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/testutil" -) - -var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) -var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) -var _ NDPEndpoint = (*testIPv6Endpoint)(nil) - -// An IPv6 NetworkEndpoint that throws away outgoing packets. -// -// We use this instead of ipv6.endpoint because the ipv6 package depends on -// the stack package which this test lives in, causing a cyclic dependency. -type testIPv6Endpoint struct { - AddressableEndpointState - - nic NetworkInterface - protocol *testIPv6Protocol - - invalidatedRtr tcpip.Address -} - -func (*testIPv6Endpoint) Enable() tcpip.Error { - return nil -} - -func (*testIPv6Endpoint) Enabled() bool { - return true -} - -func (*testIPv6Endpoint) Disable() {} - -// DefaultTTL implements NetworkEndpoint.DefaultTTL. -func (*testIPv6Endpoint) DefaultTTL() uint8 { - return 0 -} - -// MTU implements NetworkEndpoint.MTU. -func (e *testIPv6Endpoint) MTU() uint32 { - return e.nic.MTU() - header.IPv6MinimumSize -} - -// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. -func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { - return e.nic.MaxHeaderLength() + header.IPv6MinimumSize -} - -// WritePacket implements NetworkEndpoint.WritePacket. -func (*testIPv6Endpoint) WritePacket(*Route, NetworkHeaderParams, *PacketBuffer) tcpip.Error { - return nil -} - -// WritePackets implements NetworkEndpoint.WritePackets. -func (*testIPv6Endpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { - // Our tests don't use this so we don't support it. - return 0, &tcpip.ErrNotSupported{} -} - -// WriteHeaderIncludedPacket implements -// NetworkEndpoint.WriteHeaderIncludedPacket. -func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) tcpip.Error { - // Our tests don't use this so we don't support it. - return &tcpip.ErrNotSupported{} -} - -// HandlePacket implements NetworkEndpoint.HandlePacket. -func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {} - -// Close implements NetworkEndpoint.Close. -func (e *testIPv6Endpoint) Close() { - e.AddressableEndpointState.Cleanup() -} - -// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber. -func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return header.IPv6ProtocolNumber -} - -func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { - e.invalidatedRtr = rtr -} - -// Stats implements NetworkEndpoint. -func (*testIPv6Endpoint) Stats() NetworkEndpointStats { - return &testIPv6EndpointStats{} -} - -var _ NetworkEndpointStats = (*testIPv6EndpointStats)(nil) - -type testIPv6EndpointStats struct{} - -// IsNetworkEndpointStats implements stack.NetworkEndpointStats. -func (*testIPv6EndpointStats) IsNetworkEndpointStats() {} - -// We use this instead of ipv6.protocol because the ipv6 package depends on -// the stack package which this test lives in, causing a cyclic dependency. -type testIPv6Protocol struct{} - -// Number implements NetworkProtocol.Number. -func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber { - return header.IPv6ProtocolNumber -} - -// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize. -func (*testIPv6Protocol) MinimumPacketSize() int { - return header.IPv6MinimumSize -} - -// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. -func (*testIPv6Protocol) DefaultPrefixLen() int { - return header.IPv6AddressSize * 8 -} - -// ParseAddresses implements NetworkProtocol.ParseAddresses. -func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - h := header.IPv6(v) - return h.SourceAddress(), h.DestinationAddress() -} - -// NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ TransportDispatcher) NetworkEndpoint { - e := &testIPv6Endpoint{ - nic: nic, - protocol: p, - } - e.AddressableEndpointState.Init(e) - return e -} - -// SetOption implements NetworkProtocol.SetOption. -func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { - return nil -} - -// Option implements NetworkProtocol.Option. -func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { - return nil -} - -// Close implements NetworkProtocol.Close. -func (*testIPv6Protocol) Close() {} - -// Wait implements NetworkProtocol.Wait. -func (*testIPv6Protocol) Wait() {} - -// Parse implements NetworkProtocol.Parse. -func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - return 0, false, false -} - -func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { - // When the NIC is disabled, the only field that matters is the stats field. - // This test is limited to stats counter checks. - nic := nic{ - stats: makeNICStats(tcpip.NICStats{}.FillIn()), - } - - if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 { - t.Errorf("got DisabledRx.Packets = %d, want = 0", got) - } - if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 { - t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) - } - if got := nic.stats.local.Rx.Packets.Value(); got != 0 { - t.Errorf("got Rx.Packets = %d, want = 0", got) - } - if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { - t.Errorf("got Rx.Bytes = %d, want = 0", got) - } - - if t.Failed() { - t.FailNow() - } - - nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{ - Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), - })) - - if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 { - t.Errorf("got DisabledRx.Packets = %d, want = 1", got) - } - if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 { - t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) - } - if got := nic.stats.local.Rx.Packets.Value(); got != 0 { - t.Errorf("got Rx.Packets = %d, want = 0", got) - } - if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { - t.Errorf("got Rx.Bytes = %d, want = 0", got) - } -} - -func TestMultiCounterStatsInitialization(t *testing.T) { - global := tcpip.NICStats{}.FillIn() - nic := nic{ - stats: makeNICStats(global), - } - multi := nic.stats.multiCounterNICStats - local := nic.stats.local - if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go deleted file mode 100644 index 1aeb2f8a5..000000000 --- a/pkg/tcpip/stack/nud_test.go +++ /dev/null @@ -1,816 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "math" - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - defaultBaseReachableTime = 30 * time.Second - minimumBaseReachableTime = time.Millisecond - defaultMinRandomFactor = 0.5 - defaultMaxRandomFactor = 1.5 - defaultRetransmitTimer = time.Second - minimumRetransmitTimer = time.Millisecond - defaultDelayFirstProbeTime = 5 * time.Second - defaultMaxMulticastProbes = 3 - defaultMaxUnicastProbes = 3 - - defaultFakeRandomNum = 0.5 -) - -// fakeRand is a deterministic random number generator. -type fakeRand struct { - num float32 -} - -var _ rand.Source = (*fakeRand)(nil) - -func (f *fakeRand) Int63() int64 { - return int64(f.num * float32(1<<63)) -} - -func (*fakeRand) Seed(int64) {} - -func TestNUDFunctions(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - nicID tcpip.NICID - netProtoFactory []stack.NetworkProtocolFactory - extraLinkCapabilities stack.LinkEndpointCapabilities - expectedErr tcpip.Error - }{ - { - name: "Invalid NICID", - nicID: nicID + 1, - netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - extraLinkCapabilities: stack.CapabilityResolutionRequired, - expectedErr: &tcpip.ErrUnknownNICID{}, - }, - { - name: "No network protocol", - nicID: nicID, - expectedErr: &tcpip.ErrNotSupported{}, - }, - { - name: "With IPv6", - nicID: nicID, - netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - expectedErr: &tcpip.ErrNotSupported{}, - }, - { - name: "With resolution capability", - nicID: nicID, - extraLinkCapabilities: stack.CapabilityResolutionRequired, - expectedErr: &tcpip.ErrNotSupported{}, - }, - { - name: "With IPv6 and resolution capability", - nicID: nicID, - netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - extraLinkCapabilities: stack.CapabilityResolutionRequired, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), - NetworkProtocols: test.netProtoFactory, - Clock: clock, - }) - - e := channel.New(0, 0, linkAddr1) - e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired - e.LinkEPCapabilities |= test.extraLinkCapabilities - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - configs := stack.DefaultNUDConfigurations() - configs.BaseReachableTime = time.Hour - - { - err := s.SetNUDConfigurations(test.nicID, ipv6.ProtocolNumber, configs) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.SetNUDConfigurations(%d, %d, _) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } - } - - { - gotConfigs, err := s.NUDConfigurations(test.nicID, ipv6.ProtocolNumber) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.NUDConfigurations(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } else if test.expectedErr == nil { - if diff := cmp.Diff(configs, gotConfigs); diff != "" { - t.Errorf("got configs mismatch (-want +got):\n%s", diff) - } - } - } - - for _, addr := range []tcpip.Address{llAddr1, llAddr2} { - { - err := s.AddStaticNeighbor(test.nicID, ipv6.ProtocolNumber, addr, linkAddr1) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.AddStaticNeighbor(%d, %d, %s, %s) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, addr, linkAddr1, diff) - } - } - } - - { - wantErr := test.expectedErr - for i := 0; i < 2; i++ { - { - err := s.RemoveNeighbor(test.nicID, ipv6.ProtocolNumber, llAddr1) - if diff := cmp.Diff(wantErr, err); diff != "" { - t.Errorf("s.RemoveNeighbor(%d, %d, '') error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } - } - - if test.expectedErr != nil { - break - } - - // Removing a neighbor that does not exist should give us a bad address - // error. - wantErr = &tcpip.ErrBadAddress{} - } - } - - { - neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } else if test.expectedErr == nil { - if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAt: clock.Now()}}, - neighbors, - ); diff != "" { - t.Errorf("neighbors mismatch (-want +got):\n%s", diff) - } - } - } - - { - err := s.ClearNeighbors(test.nicID, ipv6.ProtocolNumber) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Errorf("s.ClearNeigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) - } else if test.expectedErr == nil { - if neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber); err != nil { - t.Errorf("s.Neighbors(%d, %d): %s", test.nicID, ipv6.ProtocolNumber, err) - } else if len(neighbors) != 0 { - t.Errorf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) - } - } - } - }) - } -} - -func TestDefaultNUDConfigurations(t *testing.T) { - const nicID = 1 - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The networking - // stack will only allocate neighbor caches if a protocol providing link - // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: stack.DefaultNUDConfigurations(), - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - c, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got, want := c, stack.DefaultNUDConfigurations(); got != want { - t.Errorf("got stack.NUDConfigurations(%d, %d) = %+v, want = %+v", nicID, ipv6.ProtocolNumber, got, want) - } -} - -func TestNUDConfigurationsBaseReachableTime(t *testing.T) { - tests := []struct { - name string - baseReachableTime time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - baseReachableTime: 0, - want: defaultBaseReachableTime, - }, - // Valid cases - { - name: "MoreThanZero", - baseReachableTime: time.Millisecond, - want: time.Millisecond, - }, - { - name: "MoreThanDefaultBaseReachableTime", - baseReachableTime: 2 * defaultBaseReachableTime, - want: 2 * defaultBaseReachableTime, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.BaseReachableTime = test.baseReachableTime - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.BaseReachableTime; got != test.want { - t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsMinRandomFactor(t *testing.T) { - tests := []struct { - name string - minRandomFactor float32 - want float32 - }{ - // Invalid cases - { - name: "LessThanZero", - minRandomFactor: -1, - want: defaultMinRandomFactor, - }, - { - name: "EqualToZero", - minRandomFactor: 0, - want: defaultMinRandomFactor, - }, - // Valid cases - { - name: "MoreThanZero", - minRandomFactor: 1, - want: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.MinRandomFactor = test.minRandomFactor - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.MinRandomFactor; got != test.want { - t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { - tests := []struct { - name string - minRandomFactor float32 - maxRandomFactor float32 - want float32 - }{ - // Invalid cases - { - name: "LessThanZero", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: -1, - want: defaultMaxRandomFactor, - }, - { - name: "EqualToZero", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: 0, - want: defaultMaxRandomFactor, - }, - { - name: "LessThanMinRandomFactor", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: defaultMinRandomFactor * 0.99, - want: defaultMaxRandomFactor, - }, - { - name: "MoreThanMinRandomFactorWhenMinRandomFactorIsLargerThanMaxRandomFactorDefault", - minRandomFactor: defaultMaxRandomFactor * 2, - maxRandomFactor: defaultMaxRandomFactor, - want: defaultMaxRandomFactor * 6, - }, - // Valid cases - { - name: "EqualToMinRandomFactor", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: defaultMinRandomFactor, - want: defaultMinRandomFactor, - }, - { - name: "MoreThanMinRandomFactor", - minRandomFactor: defaultMinRandomFactor, - maxRandomFactor: defaultMinRandomFactor * 1.1, - want: defaultMinRandomFactor * 1.1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.MinRandomFactor = test.minRandomFactor - c.MaxRandomFactor = test.maxRandomFactor - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.MaxRandomFactor; got != test.want { - t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsRetransmitTimer(t *testing.T) { - tests := []struct { - name string - retransmitTimer time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - retransmitTimer: 0, - want: defaultRetransmitTimer, - }, - { - name: "LessThanMinimumRetransmitTimer", - retransmitTimer: minimumRetransmitTimer - time.Nanosecond, - want: defaultRetransmitTimer, - }, - // Valid cases - { - name: "EqualToMinimumRetransmitTimer", - retransmitTimer: minimumRetransmitTimer, - want: minimumBaseReachableTime, - }, - { - name: "LargetThanMinimumRetransmitTimer", - retransmitTimer: 2 * minimumBaseReachableTime, - want: 2 * minimumBaseReachableTime, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.RetransmitTimer = test.retransmitTimer - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.RetransmitTimer; got != test.want { - t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { - tests := []struct { - name string - delayFirstProbeTime time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - delayFirstProbeTime: 0, - want: defaultDelayFirstProbeTime, - }, - // Valid cases - { - name: "MoreThanZero", - delayFirstProbeTime: time.Millisecond, - want: time.Millisecond, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.DelayFirstProbeTime = test.delayFirstProbeTime - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.DelayFirstProbeTime; got != test.want { - t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { - tests := []struct { - name string - maxMulticastProbes uint32 - want uint32 - }{ - // Invalid cases - { - name: "EqualToZero", - maxMulticastProbes: 0, - want: defaultMaxMulticastProbes, - }, - // Valid cases - { - name: "MoreThanZero", - maxMulticastProbes: 1, - want: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.MaxMulticastProbes = test.maxMulticastProbes - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.MaxMulticastProbes; got != test.want { - t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want) - } - }) - } -} - -func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { - tests := []struct { - name string - maxUnicastProbes uint32 - want uint32 - }{ - // Invalid cases - { - name: "EqualToZero", - maxUnicastProbes: 0, - want: defaultMaxUnicastProbes, - }, - // Valid cases - { - name: "MoreThanZero", - maxUnicastProbes: 1, - want: 1, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.MaxUnicastProbes = test.maxUnicastProbes - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) - } - if got := sc.MaxUnicastProbes; got != test.want { - t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want) - } - }) - } -} - -// TestNUDStateReachableTime verifies the correctness of the ReachableTime -// computation. -func TestNUDStateReachableTime(t *testing.T) { - tests := []struct { - name string - baseReachableTime time.Duration - minRandomFactor float32 - maxRandomFactor float32 - want time.Duration - }{ - { - name: "AllZeros", - baseReachableTime: 0, - minRandomFactor: 0, - maxRandomFactor: 0, - want: 0, - }, - { - name: "ZeroMaxRandomFactor", - baseReachableTime: time.Second, - minRandomFactor: 0, - maxRandomFactor: 0, - want: 0, - }, - { - name: "ZeroMinRandomFactor", - baseReachableTime: time.Second, - minRandomFactor: 0, - maxRandomFactor: 1, - want: time.Duration(defaultFakeRandomNum * float32(time.Second)), - }, - { - name: "FractionalRandomFactor", - baseReachableTime: time.Duration(math.MaxInt64), - minRandomFactor: 0.001, - maxRandomFactor: 0.002, - want: time.Duration((0.001 + (0.001 * defaultFakeRandomNum)) * float32(math.MaxInt64)), - }, - { - name: "MinAndMaxRandomFactorsEqual", - baseReachableTime: time.Second, - minRandomFactor: 1, - maxRandomFactor: 1, - want: time.Second, - }, - { - name: "MinAndMaxRandomFactorsDifferent", - baseReachableTime: time.Second, - minRandomFactor: 1, - maxRandomFactor: 2, - want: time.Duration((1.0 + defaultFakeRandomNum) * float32(time.Second)), - }, - { - name: "MaxInt64", - baseReachableTime: time.Duration(math.MaxInt64), - minRandomFactor: 1, - maxRandomFactor: 1, - want: time.Duration(math.MaxInt64), - }, - { - name: "Overflow", - baseReachableTime: time.Duration(math.MaxInt64), - minRandomFactor: 1.5, - maxRandomFactor: 1.5, - want: time.Duration(math.MaxInt64), - }, - { - name: "DoubleOverflow", - baseReachableTime: time.Duration(math.MaxInt64), - minRandomFactor: 2.5, - maxRandomFactor: 2.5, - want: time.Duration(math.MaxInt64), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := stack.NUDConfigurations{ - BaseReachableTime: test.baseReachableTime, - MinRandomFactor: test.minRandomFactor, - MaxRandomFactor: test.maxRandomFactor, - } - // A fake random number generator is used to ensure deterministic - // results. - rng := fakeRand{ - num: defaultFakeRandomNum, - } - var clock faketime.NullClock - s := stack.NewNUDState(c, &clock, rand.New(&rng)) - if got, want := s.ReachableTime(), test.want; got != want { - t.Errorf("got ReachableTime = %q, want = %q", got, want) - } - }) - } -} - -// TestNUDStateRecomputeReachableTime exercises the ReachableTime function -// twice to verify recomputation of reachable time when the min random factor, -// max random factor, or base reachable time changes. -func TestNUDStateRecomputeReachableTime(t *testing.T) { - const defaultBase = time.Second - const defaultMin = 2.0 * defaultMaxRandomFactor - const defaultMax = 3.0 * defaultMaxRandomFactor - - tests := []struct { - name string - baseReachableTime time.Duration - minRandomFactor float32 - maxRandomFactor float32 - want time.Duration - }{ - { - name: "BaseReachableTime", - baseReachableTime: 2 * defaultBase, - minRandomFactor: defaultMin, - maxRandomFactor: defaultMax, - want: time.Duration((defaultMin + (defaultMax-defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), - }, - { - name: "MinRandomFactor", - baseReachableTime: defaultBase, - minRandomFactor: defaultMax, - maxRandomFactor: defaultMax, - want: time.Duration(defaultMax * float32(defaultBase)), - }, - { - name: "MaxRandomFactor", - baseReachableTime: defaultBase, - minRandomFactor: defaultMin, - maxRandomFactor: defaultMin, - want: time.Duration(defaultMin * float32(defaultBase)), - }, - { - name: "BothRandomFactor", - baseReachableTime: defaultBase, - minRandomFactor: 2 * defaultMin, - maxRandomFactor: 2 * defaultMax, - want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(defaultBase)), - }, - { - name: "BaseReachableTimeAndBothRandomFactors", - baseReachableTime: 2 * defaultBase, - minRandomFactor: 2 * defaultMin, - maxRandomFactor: 2 * defaultMax, - want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := stack.DefaultNUDConfigurations() - c.BaseReachableTime = defaultBase - c.MinRandomFactor = defaultMin - c.MaxRandomFactor = defaultMax - - // A fake random number generator is used to ensure deterministic - // results. - rng := fakeRand{ - num: defaultFakeRandomNum, - } - var clock faketime.NullClock - s := stack.NewNUDState(c, &clock, rand.New(&rng)) - old := s.ReachableTime() - - if got, want := s.ReachableTime(), old; got != want { - t.Errorf("got ReachableTime = %q, want = %q", got, want) - } - - // Check for recomputation when changing the min random factor, the max - // random factor, the base reachability time, or any permutation of those - // three options. - c.BaseReachableTime = test.baseReachableTime - c.MinRandomFactor = test.minRandomFactor - c.MaxRandomFactor = test.maxRandomFactor - s.SetConfig(c) - - if got, want := s.ReachableTime(), test.want; got != want { - t.Errorf("got ReachableTime = %q, want = %q", got, want) - } - - // Verify that ReachableTime isn't recomputed when none of the - // configuration options change. The random factor is changed so that if - // a recompution were to occur, ReachableTime would change. - rng.num = defaultFakeRandomNum / 2.0 - if got, want := s.ReachableTime(), test.want; got != want { - t.Errorf("got ReachableTime = %q, want = %q", got, want) - } - }) - } -} diff --git a/pkg/tcpip/stack/packet_buffer_list.go b/pkg/tcpip/stack/packet_buffer_list.go new file mode 100644 index 000000000..ce7057d6b --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_list.go @@ -0,0 +1,221 @@ +package stack + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type PacketBufferElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (PacketBufferElementMapper) linkerFor(elem *PacketBuffer) *PacketBuffer { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type PacketBufferList struct { + head *PacketBuffer + tail *PacketBuffer +} + +// Reset resets list l to the empty state. +func (l *PacketBufferList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *PacketBufferList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *PacketBufferList) Front() *PacketBuffer { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *PacketBufferList) Back() *PacketBuffer { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *PacketBufferList) Len() (count int) { + for e := l.Front(); e != nil; e = (PacketBufferElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *PacketBufferList) PushFront(e *PacketBuffer) { + linker := PacketBufferElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + PacketBufferElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *PacketBufferList) PushBack(e *PacketBuffer) { + linker := PacketBufferElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + PacketBufferElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *PacketBufferList) PushBackList(m *PacketBufferList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + PacketBufferElementMapper{}.linkerFor(l.tail).SetNext(m.head) + PacketBufferElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *PacketBufferList) InsertAfter(b, e *PacketBuffer) { + bLinker := PacketBufferElementMapper{}.linkerFor(b) + eLinker := PacketBufferElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + PacketBufferElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *PacketBufferList) InsertBefore(a, e *PacketBuffer) { + aLinker := PacketBufferElementMapper{}.linkerFor(a) + eLinker := PacketBufferElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + PacketBufferElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *PacketBufferList) Remove(e *PacketBuffer) { + linker := PacketBufferElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + PacketBufferElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + PacketBufferElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type PacketBufferEntry struct { + next *PacketBuffer + prev *PacketBuffer +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *PacketBufferEntry) Next() *PacketBuffer { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *PacketBufferEntry) Prev() *PacketBuffer { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *PacketBufferEntry) SetNext(elem *PacketBuffer) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *PacketBufferEntry) SetPrev(elem *PacketBuffer) { + e.prev = elem +} diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go deleted file mode 100644 index a8da34992..000000000 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ /dev/null @@ -1,649 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "bytes" - "fmt" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestPacketHeaderPush(t *testing.T) { - for _, test := range []struct { - name string - reserved int - link []byte - network []byte - transport []byte - data []byte - }{ - { - name: "construct empty packet", - }, - { - name: "construct link header only packet", - reserved: 60, - link: makeView(10), - }, - { - name: "construct link and network header only packet", - reserved: 60, - link: makeView(10), - network: makeView(20), - }, - { - name: "construct header only packet", - reserved: 60, - link: makeView(10), - network: makeView(20), - transport: makeView(30), - }, - { - name: "construct data only packet", - data: makeView(40), - }, - { - name: "construct L3 packet", - reserved: 60, - network: makeView(20), - transport: makeView(30), - data: makeView(40), - }, - { - name: "construct L2 packet", - reserved: 60, - link: makeView(10), - network: makeView(20), - transport: makeView(30), - data: makeView(40), - }, - } { - t.Run(test.name, func(t *testing.T) { - pk := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: test.reserved, - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), - }) - - allHdrSize := len(test.link) + len(test.network) + len(test.transport) - - // Check the initial values for packet. - checkInitialPacketBuffer(t, pk, PacketBufferOptions{ - ReserveHeaderBytes: test.reserved, - Data: buffer.View(test.data).ToVectorisedView(), - }) - - // Push headers. - if v := test.transport; len(v) > 0 { - copy(pk.TransportHeader().Push(len(v)), v) - } - if v := test.network; len(v) > 0 { - copy(pk.NetworkHeader().Push(len(v)), v) - } - if v := test.link; len(v) > 0 { - copy(pk.LinkHeader().Push(len(v)), v) - } - - // Check the after values for packet. - if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want { - t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want { - t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.HeaderSize(), allHdrSize; got != want { - t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) - } - if got, want := pk.Size(), allHdrSize+len(test.data); got != want { - t.Errorf("After pk.Size() = %d, want %d", got, want) - } - // Check the after state. - checkPacketContents(t, "After ", pk, packetContents{ - link: test.link, - network: test.network, - transport: test.transport, - data: test.data, - }) - }) - } -} - -func TestPacketHeaderConsume(t *testing.T) { - for _, test := range []struct { - name string - data []byte - link int - network int - transport int - }{ - { - name: "parse L2 packet", - data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)), - link: 10, - network: 20, - transport: 30, - }, - { - name: "parse L3 packet", - data: concatViews(makeView(20), makeView(30), makeView(40)), - network: 20, - transport: 30, - }, - } { - t.Run(test.name, func(t *testing.T) { - pk := NewPacketBuffer(PacketBufferOptions{ - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), - }) - - // Check the initial values for packet. - checkInitialPacketBuffer(t, pk, PacketBufferOptions{ - Data: buffer.View(test.data).ToVectorisedView(), - }) - - // Consume headers. - if size := test.link; size > 0 { - if _, ok := pk.LinkHeader().Consume(size); !ok { - t.Fatalf("pk.LinkHeader().Consume() = false, want true") - } - } - if size := test.network; size > 0 { - if _, ok := pk.NetworkHeader().Consume(size); !ok { - t.Fatalf("pk.NetworkHeader().Consume() = false, want true") - } - } - if size := test.transport; size > 0 { - if _, ok := pk.TransportHeader().Consume(size); !ok { - t.Fatalf("pk.TransportHeader().Consume() = false, want true") - } - } - - allHdrSize := test.link + test.network + test.transport - - // Check the after values for packet. - if got, want := pk.ReservedHeaderBytes(), 0; got != want { - t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.AvailableHeaderBytes(), 0; got != want { - t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.HeaderSize(), allHdrSize; got != want { - t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) - } - if got, want := pk.Size(), len(test.data); got != want { - t.Errorf("After pk.Size() = %d, want %d", got, want) - } - // Check the after state of pk. - checkPacketContents(t, "After ", pk, packetContents{ - link: test.data[:test.link], - network: test.data[test.link:][:test.network], - transport: test.data[test.link+test.network:][:test.transport], - data: test.data[allHdrSize:], - }) - }) - } -} - -func TestPacketHeaderConsumeDataTooShort(t *testing.T) { - data := makeView(10) - - pk := NewPacketBuffer(PacketBufferOptions{ - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(data).ToVectorisedView(), - }) - - // Consume should fail if pkt.Data is too short. - if _, ok := pk.LinkHeader().Consume(11); ok { - t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false") - } - if _, ok := pk.NetworkHeader().Consume(11); ok { - t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false") - } - if _, ok := pk.TransportHeader().Consume(11); ok { - t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false") - } - - // Check packet should look the same as initial packet. - checkInitialPacketBuffer(t, pk, PacketBufferOptions{ - Data: buffer.View(data).ToVectorisedView(), - }) -} - -// This is a very obscure use-case seen in the code that verifies packets -// before sending them out. It tries to parse the headers to verify. -// PacketHeader was initially not designed to mix Push() and Consume(), but it -// works and it's been relied upon. Include a test here. -func TestPacketHeaderPushConsumeMixed(t *testing.T) { - link := makeView(10) - network := makeView(20) - data := makeView(30) - - initData := append([]byte(nil), network...) - initData = append(initData, data...) - pk := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: len(link), - Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), - }) - - // 1. Consume network header - gotNetwork, ok := pk.NetworkHeader().Consume(len(network)) - if !ok { - t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network)) - } - checkViewEqual(t, "gotNetwork", gotNetwork, network) - - // 2. Push link header - copy(pk.LinkHeader().Push(len(link)), link) - - checkPacketContents(t, "" /* prefix */, pk, packetContents{ - link: link, - network: network, - data: data, - }) -} - -func TestPacketHeaderPushConsumeMixedTooLong(t *testing.T) { - link := makeView(10) - network := makeView(20) - data := makeView(30) - - initData := concatViews(network, data) - pk := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: len(link), - Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), - }) - - // 1. Push link header - copy(pk.LinkHeader().Push(len(link)), link) - - checkPacketContents(t, "" /* prefix */, pk, packetContents{ - link: link, - data: initData, - }) - - // 2. Consume network header, with a number of bytes too large. - gotNetwork, ok := pk.NetworkHeader().Consume(len(initData) + 1) - if ok { - t.Fatalf("pk.NetworkHeader().Consume(%d) = %q, true; want _, false", len(initData)+1, gotNetwork) - } - - checkPacketContents(t, "" /* prefix */, pk, packetContents{ - link: link, - data: initData, - }) -} - -func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { - const headerSize = 10 - - pk := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: headerSize * int(numHeaderType), - }) - - for _, h := range []PacketHeader{ - pk.TransportHeader(), - pk.NetworkHeader(), - pk.LinkHeader(), - } { - t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) { - h.Push(headerSize) - - defer func() { recover() }() - h.Push(headerSize) - t.Fatal("Second push should have panicked") - }) - } -} - -func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) { - const headerSize = 10 - - pk := NewPacketBuffer(PacketBufferOptions{ - Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), - }) - - for _, h := range []PacketHeader{ - pk.LinkHeader(), - pk.NetworkHeader(), - pk.TransportHeader(), - } { - t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) { - if _, ok := h.Consume(headerSize); !ok { - t.Fatal("First consume should succeed") - } - - defer func() { recover() }() - h.Consume(headerSize) - t.Fatal("Second consume should have panicked") - }) - } -} - -func TestPacketHeaderPushThenConsumePanics(t *testing.T) { - const headerSize = 10 - - pk := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: headerSize * int(numHeaderType), - }) - - for _, h := range []PacketHeader{ - pk.TransportHeader(), - pk.NetworkHeader(), - pk.LinkHeader(), - } { - t.Run(h.typ.String(), func(t *testing.T) { - h.Push(headerSize) - - defer func() { recover() }() - h.Consume(headerSize) - t.Fatal("Consume should have panicked") - }) - } -} - -func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { - const headerSize = 10 - - pk := NewPacketBuffer(PacketBufferOptions{ - Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), - }) - - for _, h := range []PacketHeader{ - pk.LinkHeader(), - pk.NetworkHeader(), - pk.TransportHeader(), - } { - t.Run(h.typ.String(), func(t *testing.T) { - h.Consume(headerSize) - - defer func() { recover() }() - h.Push(headerSize) - t.Fatal("Push should have panicked") - }) - } -} - -func TestPacketBufferData(t *testing.T) { - for _, tc := range []struct { - name string - makePkt func(*testing.T) *PacketBuffer - data string - }{ - { - name: "inbound packet", - makePkt: func(*testing.T) *PacketBuffer { - pkt := NewPacketBuffer(PacketBufferOptions{ - Data: vv("aabbbbccccccDATA"), - }) - pkt.LinkHeader().Consume(2) - pkt.NetworkHeader().Consume(4) - pkt.TransportHeader().Consume(6) - return pkt - }, - data: "DATA", - }, - { - name: "outbound packet", - makePkt: func(*testing.T) *PacketBuffer { - pkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: 12, - Data: vv("DATA"), - }) - copy(pkt.TransportHeader().Push(6), []byte("cccccc")) - copy(pkt.NetworkHeader().Push(4), []byte("bbbb")) - copy(pkt.LinkHeader().Push(2), []byte("aa")) - return pkt - }, - data: "DATA", - }, - } { - t.Run(tc.name, func(t *testing.T) { - // PullUp - for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("PullUp%d", n), func(t *testing.T) { - pkt := tc.makePkt(t) - v, ok := pkt.Data().PullUp(n) - wantV := []byte(tc.data)[:n] - if !ok || !bytes.Equal(v, wantV) { - t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want %q, true", n, v, ok, wantV) - } - }) - } - t.Run("PullUpOutOfBounds", func(t *testing.T) { - n := len(tc.data) + 1 - pkt := tc.makePkt(t) - v, ok := pkt.Data().PullUp(n) - if ok || v != nil { - t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want nil, false", n, v, ok) - } - }) - - // DeleteFront - for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { - pkt := tc.makePkt(t) - pkt.Data().DeleteFront(n) - - checkData(t, pkt, []byte(tc.data)[n:]) - }) - } - - // CapLength - for _, n := range []int{0, 1, len(tc.data)} { - t.Run(fmt.Sprintf("CapLength%d", n), func(t *testing.T) { - pkt := tc.makePkt(t) - pkt.Data().CapLength(n) - - want := []byte(tc.data) - if n < len(want) { - want = want[:n] - } - checkData(t, pkt, want) - }) - } - - // Views - t.Run("Views", func(t *testing.T) { - pkt := tc.makePkt(t) - checkData(t, pkt, []byte(tc.data)) - }) - - // AppendView - t.Run("AppendView", func(t *testing.T) { - s := "APPEND" - - pkt := tc.makePkt(t) - pkt.Data().AppendView(buffer.View(s)) - - checkData(t, pkt, []byte(tc.data+s)) - }) - - // ReadFromVV - for _, n := range []int{0, 1, 2, 7, 10, 14, 20} { - t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) { - s := "TO READ" - srcVV := vv(s, s) - s += s - - pkt := tc.makePkt(t) - pkt.Data().ReadFromVV(&srcVV, n) - - if n < len(s) { - s = s[:n] - } - checkData(t, pkt, []byte(tc.data+s)) - }) - } - - // ExtractVV - t.Run("ExtractVV", func(t *testing.T) { - pkt := tc.makePkt(t) - extractedVV := pkt.Data().ExtractVV() - - got := extractedVV.ToOwnedView() - want := []byte(tc.data) - if !bytes.Equal(got, want) { - t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want) - } - }) - }) - } -} - -type packetContents struct { - link buffer.View - network buffer.View - transport buffer.View - data buffer.View -} - -func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) { - t.Helper() - // Headers. - checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link) - checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network) - checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport) - // Data. - checkData(t, pk, want.data) - // Whole packet. - checkViewEqual(t, prefix+"pk.Views()", - concatViews(pk.Views()...), - concatViews(want.link, want.network, want.transport, want.data)) - // PayloadSince. - checkViewEqual(t, prefix+"PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(want.link, want.network, want.transport, want.data)) - checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(want.network, want.transport, want.data)) - checkViewEqual(t, prefix+"PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(want.transport, want.data)) -} - -func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { - t.Helper() - reserved := opts.ReserveHeaderBytes - if got, want := pk.ReservedHeaderBytes(), reserved; got != want { - t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.AvailableHeaderBytes(), reserved; got != want { - t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want) - } - if got, want := pk.HeaderSize(), 0; got != want { - t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want) - } - data := opts.Data.ToView() - if got, want := pk.Size(), len(data); got != want { - t.Errorf("Initial pk.Size() = %d, want %d", got, want) - } - checkPacketContents(t, "Initial ", pk, packetContents{ - data: data, - }) -} - -func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { - t.Helper() - checkViewEqual(t, name+".View()", h.View(), want) -} - -func checkViewEqual(t *testing.T, what string, got, want buffer.View) { - t.Helper() - if !bytes.Equal(got, want) { - t.Errorf("%s = %x, want %x", what, got, want) - } -} - -func checkData(t *testing.T, pkt *PacketBuffer, want []byte) { - t.Helper() - if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) { - t.Errorf("pkt.Data().Views() = 0x%x, want 0x%x", got, want) - } - if got := pkt.Data().Size(); got != len(want) { - t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want)) - } - - t.Run("AsRange", func(t *testing.T) { - // Full range - checkRange(t, pkt.Data().AsRange(), want) - - // SubRange - for _, off := range []int{0, 1, len(want), len(want) + 1} { - t.Run(fmt.Sprintf("SubRange%d", off), func(t *testing.T) { - // Empty when off is greater than the size of range. - var sub []byte - if off < len(want) { - sub = want[off:] - } - checkRange(t, pkt.Data().AsRange().SubRange(off), sub) - }) - } - - // Capped - for _, n := range []int{0, 1, len(want), len(want) + 1} { - t.Run(fmt.Sprintf("Capped%d", n), func(t *testing.T) { - sub := want - if n < len(sub) { - sub = sub[:n] - } - checkRange(t, pkt.Data().AsRange().Capped(n), sub) - }) - } - }) -} - -func checkRange(t *testing.T, r Range, data []byte) { - if got, want := r.Size(), len(data); got != want { - t.Errorf("r.Size() = %d, want %d", got, want) - } - if got := r.AsView(); !bytes.Equal(got, data) { - t.Errorf("r.AsView() = %x, want %x", got, data) - } - if got := r.ToOwnedView(); !bytes.Equal(got, data) { - t.Errorf("r.ToOwnedView() = %x, want %x", got, data) - } - if got, want := r.Checksum(), header.Checksum(data, 0 /* initial */); got != want { - t.Errorf("r.Checksum() = %x, want %x", got, want) - } -} - -func vv(pieces ...string) buffer.VectorisedView { - var views []buffer.View - var size int - for _, p := range pieces { - v := buffer.View([]byte(p)) - size += len(v) - views = append(views, v) - } - return buffer.NewVectorisedView(size, views) -} - -func makeView(size int) buffer.View { - b := byte(size) - return bytes.Repeat([]byte{b}, size) -} - -func concatViews(views ...buffer.View) buffer.View { - var all buffer.View - for _, v := range views { - all = append(all, v...) - } - return all -} diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go new file mode 100644 index 000000000..b287f64c3 --- /dev/null +++ b/pkg/tcpip/stack/stack_state_autogen.go @@ -0,0 +1,1287 @@ +// automatically generated by stateify. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +func (t *tuple) StateTypeName() string { + return "pkg/tcpip/stack.tuple" +} + +func (t *tuple) StateFields() []string { + return []string{ + "tupleEntry", + "tupleID", + "conn", + "direction", + } +} + +func (t *tuple) beforeSave() {} + +// +checklocksignore +func (t *tuple) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.tupleEntry) + stateSinkObject.Save(1, &t.tupleID) + stateSinkObject.Save(2, &t.conn) + stateSinkObject.Save(3, &t.direction) +} + +func (t *tuple) afterLoad() {} + +// +checklocksignore +func (t *tuple) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.tupleEntry) + stateSourceObject.Load(1, &t.tupleID) + stateSourceObject.Load(2, &t.conn) + stateSourceObject.Load(3, &t.direction) +} + +func (ti *tupleID) StateTypeName() string { + return "pkg/tcpip/stack.tupleID" +} + +func (ti *tupleID) StateFields() []string { + return []string{ + "srcAddr", + "srcPort", + "dstAddr", + "dstPort", + "transProto", + "netProto", + } +} + +func (ti *tupleID) beforeSave() {} + +// +checklocksignore +func (ti *tupleID) StateSave(stateSinkObject state.Sink) { + ti.beforeSave() + stateSinkObject.Save(0, &ti.srcAddr) + stateSinkObject.Save(1, &ti.srcPort) + stateSinkObject.Save(2, &ti.dstAddr) + stateSinkObject.Save(3, &ti.dstPort) + stateSinkObject.Save(4, &ti.transProto) + stateSinkObject.Save(5, &ti.netProto) +} + +func (ti *tupleID) afterLoad() {} + +// +checklocksignore +func (ti *tupleID) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ti.srcAddr) + stateSourceObject.Load(1, &ti.srcPort) + stateSourceObject.Load(2, &ti.dstAddr) + stateSourceObject.Load(3, &ti.dstPort) + stateSourceObject.Load(4, &ti.transProto) + stateSourceObject.Load(5, &ti.netProto) +} + +func (cn *conn) StateTypeName() string { + return "pkg/tcpip/stack.conn" +} + +func (cn *conn) StateFields() []string { + return []string{ + "original", + "reply", + "manip", + "tcbHook", + "tcb", + "lastUsed", + } +} + +func (cn *conn) beforeSave() {} + +// +checklocksignore +func (cn *conn) StateSave(stateSinkObject state.Sink) { + cn.beforeSave() + var lastUsedValue unixTime = cn.saveLastUsed() + stateSinkObject.SaveValue(5, lastUsedValue) + stateSinkObject.Save(0, &cn.original) + stateSinkObject.Save(1, &cn.reply) + stateSinkObject.Save(2, &cn.manip) + stateSinkObject.Save(3, &cn.tcbHook) + stateSinkObject.Save(4, &cn.tcb) +} + +func (cn *conn) afterLoad() {} + +// +checklocksignore +func (cn *conn) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &cn.original) + stateSourceObject.Load(1, &cn.reply) + stateSourceObject.Load(2, &cn.manip) + stateSourceObject.Load(3, &cn.tcbHook) + stateSourceObject.Load(4, &cn.tcb) + stateSourceObject.LoadValue(5, new(unixTime), func(y interface{}) { cn.loadLastUsed(y.(unixTime)) }) +} + +func (ct *ConnTrack) StateTypeName() string { + return "pkg/tcpip/stack.ConnTrack" +} + +func (ct *ConnTrack) StateFields() []string { + return []string{ + "seed", + "buckets", + } +} + +// +checklocksignore +func (ct *ConnTrack) StateSave(stateSinkObject state.Sink) { + ct.beforeSave() + stateSinkObject.Save(0, &ct.seed) + stateSinkObject.Save(1, &ct.buckets) +} + +func (ct *ConnTrack) afterLoad() {} + +// +checklocksignore +func (ct *ConnTrack) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ct.seed) + stateSourceObject.Load(1, &ct.buckets) +} + +func (b *bucket) StateTypeName() string { + return "pkg/tcpip/stack.bucket" +} + +func (b *bucket) StateFields() []string { + return []string{ + "tuples", + } +} + +func (b *bucket) beforeSave() {} + +// +checklocksignore +func (b *bucket) StateSave(stateSinkObject state.Sink) { + b.beforeSave() + stateSinkObject.Save(0, &b.tuples) +} + +func (b *bucket) afterLoad() {} + +// +checklocksignore +func (b *bucket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &b.tuples) +} + +func (u *unixTime) StateTypeName() string { + return "pkg/tcpip/stack.unixTime" +} + +func (u *unixTime) StateFields() []string { + return []string{ + "second", + "nano", + } +} + +func (u *unixTime) beforeSave() {} + +// +checklocksignore +func (u *unixTime) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + stateSinkObject.Save(0, &u.second) + stateSinkObject.Save(1, &u.nano) +} + +func (u *unixTime) afterLoad() {} + +// +checklocksignore +func (u *unixTime) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.second) + stateSourceObject.Load(1, &u.nano) +} + +func (it *IPTables) StateTypeName() string { + return "pkg/tcpip/stack.IPTables" +} + +func (it *IPTables) StateFields() []string { + return []string{ + "mu", + "v4Tables", + "v6Tables", + "modified", + "priorities", + "connections", + "reaperDone", + } +} + +// +checklocksignore +func (it *IPTables) StateSave(stateSinkObject state.Sink) { + it.beforeSave() + stateSinkObject.Save(0, &it.mu) + stateSinkObject.Save(1, &it.v4Tables) + stateSinkObject.Save(2, &it.v6Tables) + stateSinkObject.Save(3, &it.modified) + stateSinkObject.Save(4, &it.priorities) + stateSinkObject.Save(5, &it.connections) + stateSinkObject.Save(6, &it.reaperDone) +} + +// +checklocksignore +func (it *IPTables) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &it.mu) + stateSourceObject.Load(1, &it.v4Tables) + stateSourceObject.Load(2, &it.v6Tables) + stateSourceObject.Load(3, &it.modified) + stateSourceObject.Load(4, &it.priorities) + stateSourceObject.Load(5, &it.connections) + stateSourceObject.Load(6, &it.reaperDone) + stateSourceObject.AfterLoad(it.afterLoad) +} + +func (table *Table) StateTypeName() string { + return "pkg/tcpip/stack.Table" +} + +func (table *Table) StateFields() []string { + return []string{ + "Rules", + "BuiltinChains", + "Underflows", + } +} + +func (table *Table) beforeSave() {} + +// +checklocksignore +func (table *Table) StateSave(stateSinkObject state.Sink) { + table.beforeSave() + stateSinkObject.Save(0, &table.Rules) + stateSinkObject.Save(1, &table.BuiltinChains) + stateSinkObject.Save(2, &table.Underflows) +} + +func (table *Table) afterLoad() {} + +// +checklocksignore +func (table *Table) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &table.Rules) + stateSourceObject.Load(1, &table.BuiltinChains) + stateSourceObject.Load(2, &table.Underflows) +} + +func (r *Rule) StateTypeName() string { + return "pkg/tcpip/stack.Rule" +} + +func (r *Rule) StateFields() []string { + return []string{ + "Filter", + "Matchers", + "Target", + } +} + +func (r *Rule) beforeSave() {} + +// +checklocksignore +func (r *Rule) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.Filter) + stateSinkObject.Save(1, &r.Matchers) + stateSinkObject.Save(2, &r.Target) +} + +func (r *Rule) afterLoad() {} + +// +checklocksignore +func (r *Rule) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.Filter) + stateSourceObject.Load(1, &r.Matchers) + stateSourceObject.Load(2, &r.Target) +} + +func (fl *IPHeaderFilter) StateTypeName() string { + return "pkg/tcpip/stack.IPHeaderFilter" +} + +func (fl *IPHeaderFilter) StateFields() []string { + return []string{ + "Protocol", + "CheckProtocol", + "Dst", + "DstMask", + "DstInvert", + "Src", + "SrcMask", + "SrcInvert", + "InputInterface", + "InputInterfaceMask", + "InputInterfaceInvert", + "OutputInterface", + "OutputInterfaceMask", + "OutputInterfaceInvert", + } +} + +func (fl *IPHeaderFilter) beforeSave() {} + +// +checklocksignore +func (fl *IPHeaderFilter) StateSave(stateSinkObject state.Sink) { + fl.beforeSave() + stateSinkObject.Save(0, &fl.Protocol) + stateSinkObject.Save(1, &fl.CheckProtocol) + stateSinkObject.Save(2, &fl.Dst) + stateSinkObject.Save(3, &fl.DstMask) + stateSinkObject.Save(4, &fl.DstInvert) + stateSinkObject.Save(5, &fl.Src) + stateSinkObject.Save(6, &fl.SrcMask) + stateSinkObject.Save(7, &fl.SrcInvert) + stateSinkObject.Save(8, &fl.InputInterface) + stateSinkObject.Save(9, &fl.InputInterfaceMask) + stateSinkObject.Save(10, &fl.InputInterfaceInvert) + stateSinkObject.Save(11, &fl.OutputInterface) + stateSinkObject.Save(12, &fl.OutputInterfaceMask) + stateSinkObject.Save(13, &fl.OutputInterfaceInvert) +} + +func (fl *IPHeaderFilter) afterLoad() {} + +// +checklocksignore +func (fl *IPHeaderFilter) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &fl.Protocol) + stateSourceObject.Load(1, &fl.CheckProtocol) + stateSourceObject.Load(2, &fl.Dst) + stateSourceObject.Load(3, &fl.DstMask) + stateSourceObject.Load(4, &fl.DstInvert) + stateSourceObject.Load(5, &fl.Src) + stateSourceObject.Load(6, &fl.SrcMask) + stateSourceObject.Load(7, &fl.SrcInvert) + stateSourceObject.Load(8, &fl.InputInterface) + stateSourceObject.Load(9, &fl.InputInterfaceMask) + stateSourceObject.Load(10, &fl.InputInterfaceInvert) + stateSourceObject.Load(11, &fl.OutputInterface) + stateSourceObject.Load(12, &fl.OutputInterfaceMask) + stateSourceObject.Load(13, &fl.OutputInterfaceInvert) +} + +func (l *neighborEntryList) StateTypeName() string { + return "pkg/tcpip/stack.neighborEntryList" +} + +func (l *neighborEntryList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *neighborEntryList) beforeSave() {} + +// +checklocksignore +func (l *neighborEntryList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *neighborEntryList) afterLoad() {} + +// +checklocksignore +func (l *neighborEntryList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *neighborEntryEntry) StateTypeName() string { + return "pkg/tcpip/stack.neighborEntryEntry" +} + +func (e *neighborEntryEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *neighborEntryEntry) beforeSave() {} + +// +checklocksignore +func (e *neighborEntryEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *neighborEntryEntry) afterLoad() {} + +// +checklocksignore +func (e *neighborEntryEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (p *PacketBufferList) StateTypeName() string { + return "pkg/tcpip/stack.PacketBufferList" +} + +func (p *PacketBufferList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (p *PacketBufferList) beforeSave() {} + +// +checklocksignore +func (p *PacketBufferList) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + stateSinkObject.Save(0, &p.head) + stateSinkObject.Save(1, &p.tail) +} + +func (p *PacketBufferList) afterLoad() {} + +// +checklocksignore +func (p *PacketBufferList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.head) + stateSourceObject.Load(1, &p.tail) +} + +func (e *PacketBufferEntry) StateTypeName() string { + return "pkg/tcpip/stack.PacketBufferEntry" +} + +func (e *PacketBufferEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *PacketBufferEntry) beforeSave() {} + +// +checklocksignore +func (e *PacketBufferEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *PacketBufferEntry) afterLoad() {} + +// +checklocksignore +func (e *PacketBufferEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (t *TransportEndpointID) StateTypeName() string { + return "pkg/tcpip/stack.TransportEndpointID" +} + +func (t *TransportEndpointID) StateFields() []string { + return []string{ + "LocalPort", + "LocalAddress", + "RemotePort", + "RemoteAddress", + } +} + +func (t *TransportEndpointID) beforeSave() {} + +// +checklocksignore +func (t *TransportEndpointID) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.LocalPort) + stateSinkObject.Save(1, &t.LocalAddress) + stateSinkObject.Save(2, &t.RemotePort) + stateSinkObject.Save(3, &t.RemoteAddress) +} + +func (t *TransportEndpointID) afterLoad() {} + +// +checklocksignore +func (t *TransportEndpointID) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.LocalPort) + stateSourceObject.Load(1, &t.LocalAddress) + stateSourceObject.Load(2, &t.RemotePort) + stateSourceObject.Load(3, &t.RemoteAddress) +} + +func (g *GSOType) StateTypeName() string { + return "pkg/tcpip/stack.GSOType" +} + +func (g *GSOType) StateFields() []string { + return nil +} + +func (g *GSO) StateTypeName() string { + return "pkg/tcpip/stack.GSO" +} + +func (g *GSO) StateFields() []string { + return []string{ + "Type", + "NeedsCsum", + "CsumOffset", + "MSS", + "L3HdrLen", + "MaxSize", + } +} + +func (g *GSO) beforeSave() {} + +// +checklocksignore +func (g *GSO) StateSave(stateSinkObject state.Sink) { + g.beforeSave() + stateSinkObject.Save(0, &g.Type) + stateSinkObject.Save(1, &g.NeedsCsum) + stateSinkObject.Save(2, &g.CsumOffset) + stateSinkObject.Save(3, &g.MSS) + stateSinkObject.Save(4, &g.L3HdrLen) + stateSinkObject.Save(5, &g.MaxSize) +} + +func (g *GSO) afterLoad() {} + +// +checklocksignore +func (g *GSO) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &g.Type) + stateSourceObject.Load(1, &g.NeedsCsum) + stateSourceObject.Load(2, &g.CsumOffset) + stateSourceObject.Load(3, &g.MSS) + stateSourceObject.Load(4, &g.L3HdrLen) + stateSourceObject.Load(5, &g.MaxSize) +} + +func (t *TransportEndpointInfo) StateTypeName() string { + return "pkg/tcpip/stack.TransportEndpointInfo" +} + +func (t *TransportEndpointInfo) StateFields() []string { + return []string{ + "NetProto", + "TransProto", + "ID", + "BindNICID", + "BindAddr", + "RegisterNICID", + } +} + +func (t *TransportEndpointInfo) beforeSave() {} + +// +checklocksignore +func (t *TransportEndpointInfo) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.NetProto) + stateSinkObject.Save(1, &t.TransProto) + stateSinkObject.Save(2, &t.ID) + stateSinkObject.Save(3, &t.BindNICID) + stateSinkObject.Save(4, &t.BindAddr) + stateSinkObject.Save(5, &t.RegisterNICID) +} + +func (t *TransportEndpointInfo) afterLoad() {} + +// +checklocksignore +func (t *TransportEndpointInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.NetProto) + stateSourceObject.Load(1, &t.TransProto) + stateSourceObject.Load(2, &t.ID) + stateSourceObject.Load(3, &t.BindNICID) + stateSourceObject.Load(4, &t.BindAddr) + stateSourceObject.Load(5, &t.RegisterNICID) +} + +func (t *TCPCubicState) StateTypeName() string { + return "pkg/tcpip/stack.TCPCubicState" +} + +func (t *TCPCubicState) StateFields() []string { + return []string{ + "WLastMax", + "WMax", + "T", + "TimeSinceLastCongestion", + "C", + "K", + "Beta", + "WC", + "WEst", + } +} + +func (t *TCPCubicState) beforeSave() {} + +// +checklocksignore +func (t *TCPCubicState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.WLastMax) + stateSinkObject.Save(1, &t.WMax) + stateSinkObject.Save(2, &t.T) + stateSinkObject.Save(3, &t.TimeSinceLastCongestion) + stateSinkObject.Save(4, &t.C) + stateSinkObject.Save(5, &t.K) + stateSinkObject.Save(6, &t.Beta) + stateSinkObject.Save(7, &t.WC) + stateSinkObject.Save(8, &t.WEst) +} + +func (t *TCPCubicState) afterLoad() {} + +// +checklocksignore +func (t *TCPCubicState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.WLastMax) + stateSourceObject.Load(1, &t.WMax) + stateSourceObject.Load(2, &t.T) + stateSourceObject.Load(3, &t.TimeSinceLastCongestion) + stateSourceObject.Load(4, &t.C) + stateSourceObject.Load(5, &t.K) + stateSourceObject.Load(6, &t.Beta) + stateSourceObject.Load(7, &t.WC) + stateSourceObject.Load(8, &t.WEst) +} + +func (t *TCPRACKState) StateTypeName() string { + return "pkg/tcpip/stack.TCPRACKState" +} + +func (t *TCPRACKState) StateFields() []string { + return []string{ + "XmitTime", + "EndSequence", + "FACK", + "RTT", + "Reord", + "DSACKSeen", + "ReoWnd", + "ReoWndIncr", + "ReoWndPersist", + "RTTSeq", + } +} + +func (t *TCPRACKState) beforeSave() {} + +// +checklocksignore +func (t *TCPRACKState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.XmitTime) + stateSinkObject.Save(1, &t.EndSequence) + stateSinkObject.Save(2, &t.FACK) + stateSinkObject.Save(3, &t.RTT) + stateSinkObject.Save(4, &t.Reord) + stateSinkObject.Save(5, &t.DSACKSeen) + stateSinkObject.Save(6, &t.ReoWnd) + stateSinkObject.Save(7, &t.ReoWndIncr) + stateSinkObject.Save(8, &t.ReoWndPersist) + stateSinkObject.Save(9, &t.RTTSeq) +} + +func (t *TCPRACKState) afterLoad() {} + +// +checklocksignore +func (t *TCPRACKState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.XmitTime) + stateSourceObject.Load(1, &t.EndSequence) + stateSourceObject.Load(2, &t.FACK) + stateSourceObject.Load(3, &t.RTT) + stateSourceObject.Load(4, &t.Reord) + stateSourceObject.Load(5, &t.DSACKSeen) + stateSourceObject.Load(6, &t.ReoWnd) + stateSourceObject.Load(7, &t.ReoWndIncr) + stateSourceObject.Load(8, &t.ReoWndPersist) + stateSourceObject.Load(9, &t.RTTSeq) +} + +func (t *TCPEndpointID) StateTypeName() string { + return "pkg/tcpip/stack.TCPEndpointID" +} + +func (t *TCPEndpointID) StateFields() []string { + return []string{ + "LocalPort", + "LocalAddress", + "RemotePort", + "RemoteAddress", + } +} + +func (t *TCPEndpointID) beforeSave() {} + +// +checklocksignore +func (t *TCPEndpointID) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.LocalPort) + stateSinkObject.Save(1, &t.LocalAddress) + stateSinkObject.Save(2, &t.RemotePort) + stateSinkObject.Save(3, &t.RemoteAddress) +} + +func (t *TCPEndpointID) afterLoad() {} + +// +checklocksignore +func (t *TCPEndpointID) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.LocalPort) + stateSourceObject.Load(1, &t.LocalAddress) + stateSourceObject.Load(2, &t.RemotePort) + stateSourceObject.Load(3, &t.RemoteAddress) +} + +func (t *TCPFastRecoveryState) StateTypeName() string { + return "pkg/tcpip/stack.TCPFastRecoveryState" +} + +func (t *TCPFastRecoveryState) StateFields() []string { + return []string{ + "Active", + "First", + "Last", + "MaxCwnd", + "HighRxt", + "RescueRxt", + } +} + +func (t *TCPFastRecoveryState) beforeSave() {} + +// +checklocksignore +func (t *TCPFastRecoveryState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.Active) + stateSinkObject.Save(1, &t.First) + stateSinkObject.Save(2, &t.Last) + stateSinkObject.Save(3, &t.MaxCwnd) + stateSinkObject.Save(4, &t.HighRxt) + stateSinkObject.Save(5, &t.RescueRxt) +} + +func (t *TCPFastRecoveryState) afterLoad() {} + +// +checklocksignore +func (t *TCPFastRecoveryState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.Active) + stateSourceObject.Load(1, &t.First) + stateSourceObject.Load(2, &t.Last) + stateSourceObject.Load(3, &t.MaxCwnd) + stateSourceObject.Load(4, &t.HighRxt) + stateSourceObject.Load(5, &t.RescueRxt) +} + +func (t *TCPReceiverState) StateTypeName() string { + return "pkg/tcpip/stack.TCPReceiverState" +} + +func (t *TCPReceiverState) StateFields() []string { + return []string{ + "RcvNxt", + "RcvAcc", + "RcvWndScale", + "PendingBufUsed", + } +} + +func (t *TCPReceiverState) beforeSave() {} + +// +checklocksignore +func (t *TCPReceiverState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.RcvNxt) + stateSinkObject.Save(1, &t.RcvAcc) + stateSinkObject.Save(2, &t.RcvWndScale) + stateSinkObject.Save(3, &t.PendingBufUsed) +} + +func (t *TCPReceiverState) afterLoad() {} + +// +checklocksignore +func (t *TCPReceiverState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.RcvNxt) + stateSourceObject.Load(1, &t.RcvAcc) + stateSourceObject.Load(2, &t.RcvWndScale) + stateSourceObject.Load(3, &t.PendingBufUsed) +} + +func (t *TCPRTTState) StateTypeName() string { + return "pkg/tcpip/stack.TCPRTTState" +} + +func (t *TCPRTTState) StateFields() []string { + return []string{ + "SRTT", + "RTTVar", + "SRTTInited", + } +} + +func (t *TCPRTTState) beforeSave() {} + +// +checklocksignore +func (t *TCPRTTState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.SRTT) + stateSinkObject.Save(1, &t.RTTVar) + stateSinkObject.Save(2, &t.SRTTInited) +} + +func (t *TCPRTTState) afterLoad() {} + +// +checklocksignore +func (t *TCPRTTState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.SRTT) + stateSourceObject.Load(1, &t.RTTVar) + stateSourceObject.Load(2, &t.SRTTInited) +} + +func (t *TCPSenderState) StateTypeName() string { + return "pkg/tcpip/stack.TCPSenderState" +} + +func (t *TCPSenderState) StateFields() []string { + return []string{ + "LastSendTime", + "DupAckCount", + "SndCwnd", + "Ssthresh", + "SndCAAckCount", + "Outstanding", + "SackedOut", + "SndWnd", + "SndUna", + "SndNxt", + "RTTMeasureSeqNum", + "RTTMeasureTime", + "Closed", + "RTO", + "RTTState", + "MaxPayloadSize", + "SndWndScale", + "MaxSentAck", + "FastRecovery", + "Cubic", + "RACKState", + } +} + +func (t *TCPSenderState) beforeSave() {} + +// +checklocksignore +func (t *TCPSenderState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.LastSendTime) + stateSinkObject.Save(1, &t.DupAckCount) + stateSinkObject.Save(2, &t.SndCwnd) + stateSinkObject.Save(3, &t.Ssthresh) + stateSinkObject.Save(4, &t.SndCAAckCount) + stateSinkObject.Save(5, &t.Outstanding) + stateSinkObject.Save(6, &t.SackedOut) + stateSinkObject.Save(7, &t.SndWnd) + stateSinkObject.Save(8, &t.SndUna) + stateSinkObject.Save(9, &t.SndNxt) + stateSinkObject.Save(10, &t.RTTMeasureSeqNum) + stateSinkObject.Save(11, &t.RTTMeasureTime) + stateSinkObject.Save(12, &t.Closed) + stateSinkObject.Save(13, &t.RTO) + stateSinkObject.Save(14, &t.RTTState) + stateSinkObject.Save(15, &t.MaxPayloadSize) + stateSinkObject.Save(16, &t.SndWndScale) + stateSinkObject.Save(17, &t.MaxSentAck) + stateSinkObject.Save(18, &t.FastRecovery) + stateSinkObject.Save(19, &t.Cubic) + stateSinkObject.Save(20, &t.RACKState) +} + +func (t *TCPSenderState) afterLoad() {} + +// +checklocksignore +func (t *TCPSenderState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.LastSendTime) + stateSourceObject.Load(1, &t.DupAckCount) + stateSourceObject.Load(2, &t.SndCwnd) + stateSourceObject.Load(3, &t.Ssthresh) + stateSourceObject.Load(4, &t.SndCAAckCount) + stateSourceObject.Load(5, &t.Outstanding) + stateSourceObject.Load(6, &t.SackedOut) + stateSourceObject.Load(7, &t.SndWnd) + stateSourceObject.Load(8, &t.SndUna) + stateSourceObject.Load(9, &t.SndNxt) + stateSourceObject.Load(10, &t.RTTMeasureSeqNum) + stateSourceObject.Load(11, &t.RTTMeasureTime) + stateSourceObject.Load(12, &t.Closed) + stateSourceObject.Load(13, &t.RTO) + stateSourceObject.Load(14, &t.RTTState) + stateSourceObject.Load(15, &t.MaxPayloadSize) + stateSourceObject.Load(16, &t.SndWndScale) + stateSourceObject.Load(17, &t.MaxSentAck) + stateSourceObject.Load(18, &t.FastRecovery) + stateSourceObject.Load(19, &t.Cubic) + stateSourceObject.Load(20, &t.RACKState) +} + +func (t *TCPSACKInfo) StateTypeName() string { + return "pkg/tcpip/stack.TCPSACKInfo" +} + +func (t *TCPSACKInfo) StateFields() []string { + return []string{ + "Blocks", + "ReceivedBlocks", + "MaxSACKED", + } +} + +func (t *TCPSACKInfo) beforeSave() {} + +// +checklocksignore +func (t *TCPSACKInfo) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.Blocks) + stateSinkObject.Save(1, &t.ReceivedBlocks) + stateSinkObject.Save(2, &t.MaxSACKED) +} + +func (t *TCPSACKInfo) afterLoad() {} + +// +checklocksignore +func (t *TCPSACKInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.Blocks) + stateSourceObject.Load(1, &t.ReceivedBlocks) + stateSourceObject.Load(2, &t.MaxSACKED) +} + +func (r *RcvBufAutoTuneParams) StateTypeName() string { + return "pkg/tcpip/stack.RcvBufAutoTuneParams" +} + +func (r *RcvBufAutoTuneParams) StateFields() []string { + return []string{ + "MeasureTime", + "CopiedBytes", + "PrevCopiedBytes", + "RcvBufSize", + "RTT", + "RTTVar", + "RTTMeasureSeqNumber", + "RTTMeasureTime", + "Disabled", + } +} + +func (r *RcvBufAutoTuneParams) beforeSave() {} + +// +checklocksignore +func (r *RcvBufAutoTuneParams) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.MeasureTime) + stateSinkObject.Save(1, &r.CopiedBytes) + stateSinkObject.Save(2, &r.PrevCopiedBytes) + stateSinkObject.Save(3, &r.RcvBufSize) + stateSinkObject.Save(4, &r.RTT) + stateSinkObject.Save(5, &r.RTTVar) + stateSinkObject.Save(6, &r.RTTMeasureSeqNumber) + stateSinkObject.Save(7, &r.RTTMeasureTime) + stateSinkObject.Save(8, &r.Disabled) +} + +func (r *RcvBufAutoTuneParams) afterLoad() {} + +// +checklocksignore +func (r *RcvBufAutoTuneParams) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.MeasureTime) + stateSourceObject.Load(1, &r.CopiedBytes) + stateSourceObject.Load(2, &r.PrevCopiedBytes) + stateSourceObject.Load(3, &r.RcvBufSize) + stateSourceObject.Load(4, &r.RTT) + stateSourceObject.Load(5, &r.RTTVar) + stateSourceObject.Load(6, &r.RTTMeasureSeqNumber) + stateSourceObject.Load(7, &r.RTTMeasureTime) + stateSourceObject.Load(8, &r.Disabled) +} + +func (t *TCPRcvBufState) StateTypeName() string { + return "pkg/tcpip/stack.TCPRcvBufState" +} + +func (t *TCPRcvBufState) StateFields() []string { + return []string{ + "RcvBufUsed", + "RcvAutoParams", + "RcvClosed", + } +} + +func (t *TCPRcvBufState) beforeSave() {} + +// +checklocksignore +func (t *TCPRcvBufState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.RcvBufUsed) + stateSinkObject.Save(1, &t.RcvAutoParams) + stateSinkObject.Save(2, &t.RcvClosed) +} + +func (t *TCPRcvBufState) afterLoad() {} + +// +checklocksignore +func (t *TCPRcvBufState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.RcvBufUsed) + stateSourceObject.Load(1, &t.RcvAutoParams) + stateSourceObject.Load(2, &t.RcvClosed) +} + +func (t *TCPSndBufState) StateTypeName() string { + return "pkg/tcpip/stack.TCPSndBufState" +} + +func (t *TCPSndBufState) StateFields() []string { + return []string{ + "SndBufSize", + "SndBufUsed", + "SndClosed", + "SndBufInQueue", + "PacketTooBigCount", + "SndMTU", + } +} + +func (t *TCPSndBufState) beforeSave() {} + +// +checklocksignore +func (t *TCPSndBufState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.SndBufSize) + stateSinkObject.Save(1, &t.SndBufUsed) + stateSinkObject.Save(2, &t.SndClosed) + stateSinkObject.Save(3, &t.SndBufInQueue) + stateSinkObject.Save(4, &t.PacketTooBigCount) + stateSinkObject.Save(5, &t.SndMTU) +} + +func (t *TCPSndBufState) afterLoad() {} + +// +checklocksignore +func (t *TCPSndBufState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.SndBufSize) + stateSourceObject.Load(1, &t.SndBufUsed) + stateSourceObject.Load(2, &t.SndClosed) + stateSourceObject.Load(3, &t.SndBufInQueue) + stateSourceObject.Load(4, &t.PacketTooBigCount) + stateSourceObject.Load(5, &t.SndMTU) +} + +func (t *TCPEndpointStateInner) StateTypeName() string { + return "pkg/tcpip/stack.TCPEndpointStateInner" +} + +func (t *TCPEndpointStateInner) StateFields() []string { + return []string{ + "TSOffset", + "SACKPermitted", + "SendTSOk", + "RecentTS", + } +} + +func (t *TCPEndpointStateInner) beforeSave() {} + +// +checklocksignore +func (t *TCPEndpointStateInner) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.TSOffset) + stateSinkObject.Save(1, &t.SACKPermitted) + stateSinkObject.Save(2, &t.SendTSOk) + stateSinkObject.Save(3, &t.RecentTS) +} + +func (t *TCPEndpointStateInner) afterLoad() {} + +// +checklocksignore +func (t *TCPEndpointStateInner) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.TSOffset) + stateSourceObject.Load(1, &t.SACKPermitted) + stateSourceObject.Load(2, &t.SendTSOk) + stateSourceObject.Load(3, &t.RecentTS) +} + +func (t *TCPEndpointState) StateTypeName() string { + return "pkg/tcpip/stack.TCPEndpointState" +} + +func (t *TCPEndpointState) StateFields() []string { + return []string{ + "TCPEndpointStateInner", + "ID", + "SegTime", + "RcvBufState", + "SndBufState", + "SACK", + "Receiver", + "Sender", + } +} + +func (t *TCPEndpointState) beforeSave() {} + +// +checklocksignore +func (t *TCPEndpointState) StateSave(stateSinkObject state.Sink) { + t.beforeSave() + stateSinkObject.Save(0, &t.TCPEndpointStateInner) + stateSinkObject.Save(1, &t.ID) + stateSinkObject.Save(2, &t.SegTime) + stateSinkObject.Save(3, &t.RcvBufState) + stateSinkObject.Save(4, &t.SndBufState) + stateSinkObject.Save(5, &t.SACK) + stateSinkObject.Save(6, &t.Receiver) + stateSinkObject.Save(7, &t.Sender) +} + +func (t *TCPEndpointState) afterLoad() {} + +// +checklocksignore +func (t *TCPEndpointState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &t.TCPEndpointStateInner) + stateSourceObject.Load(1, &t.ID) + stateSourceObject.Load(2, &t.SegTime) + stateSourceObject.Load(3, &t.RcvBufState) + stateSourceObject.Load(4, &t.SndBufState) + stateSourceObject.Load(5, &t.SACK) + stateSourceObject.Load(6, &t.Receiver) + stateSourceObject.Load(7, &t.Sender) +} + +func (ep *multiPortEndpoint) StateTypeName() string { + return "pkg/tcpip/stack.multiPortEndpoint" +} + +func (ep *multiPortEndpoint) StateFields() []string { + return []string{ + "demux", + "netProto", + "transProto", + "endpoints", + "flags", + } +} + +func (ep *multiPortEndpoint) beforeSave() {} + +// +checklocksignore +func (ep *multiPortEndpoint) StateSave(stateSinkObject state.Sink) { + ep.beforeSave() + stateSinkObject.Save(0, &ep.demux) + stateSinkObject.Save(1, &ep.netProto) + stateSinkObject.Save(2, &ep.transProto) + stateSinkObject.Save(3, &ep.endpoints) + stateSinkObject.Save(4, &ep.flags) +} + +func (ep *multiPortEndpoint) afterLoad() {} + +// +checklocksignore +func (ep *multiPortEndpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ep.demux) + stateSourceObject.Load(1, &ep.netProto) + stateSourceObject.Load(2, &ep.transProto) + stateSourceObject.Load(3, &ep.endpoints) + stateSourceObject.Load(4, &ep.flags) +} + +func (l *tupleList) StateTypeName() string { + return "pkg/tcpip/stack.tupleList" +} + +func (l *tupleList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *tupleList) beforeSave() {} + +// +checklocksignore +func (l *tupleList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *tupleList) afterLoad() {} + +// +checklocksignore +func (l *tupleList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *tupleEntry) StateTypeName() string { + return "pkg/tcpip/stack.tupleEntry" +} + +func (e *tupleEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *tupleEntry) beforeSave() {} + +// +checklocksignore +func (e *tupleEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *tupleEntry) afterLoad() {} + +// +checklocksignore +func (e *tupleEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*tuple)(nil)) + state.Register((*tupleID)(nil)) + state.Register((*conn)(nil)) + state.Register((*ConnTrack)(nil)) + state.Register((*bucket)(nil)) + state.Register((*unixTime)(nil)) + state.Register((*IPTables)(nil)) + state.Register((*Table)(nil)) + state.Register((*Rule)(nil)) + state.Register((*IPHeaderFilter)(nil)) + state.Register((*neighborEntryList)(nil)) + state.Register((*neighborEntryEntry)(nil)) + state.Register((*PacketBufferList)(nil)) + state.Register((*PacketBufferEntry)(nil)) + state.Register((*TransportEndpointID)(nil)) + state.Register((*GSOType)(nil)) + state.Register((*GSO)(nil)) + state.Register((*TransportEndpointInfo)(nil)) + state.Register((*TCPCubicState)(nil)) + state.Register((*TCPRACKState)(nil)) + state.Register((*TCPEndpointID)(nil)) + state.Register((*TCPFastRecoveryState)(nil)) + state.Register((*TCPReceiverState)(nil)) + state.Register((*TCPRTTState)(nil)) + state.Register((*TCPSenderState)(nil)) + state.Register((*TCPSACKInfo)(nil)) + state.Register((*RcvBufAutoTuneParams)(nil)) + state.Register((*TCPRcvBufState)(nil)) + state.Register((*TCPSndBufState)(nil)) + state.Register((*TCPEndpointStateInner)(nil)) + state.Register((*TCPEndpointState)(nil)) + state.Register((*multiPortEndpoint)(nil)) + state.Register((*tupleList)(nil)) + state.Register((*tupleEntry)(nil)) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go deleted file mode 100644 index 21951d05a..000000000 --- a/pkg/tcpip/stack/stack_test.go +++ /dev/null @@ -1,4543 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package stack_test contains tests for the stack. It is in its own package so -// that the tests can also validate that all definitions needed to implement -// transport and network protocols are properly exported by the stack package. -package stack_test - -import ( - "bytes" - "fmt" - "math" - "net" - "sort" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -const ( - fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fakeNetHeaderLen = 12 - fakeDefaultPrefixLen = 8 - - // fakeControlProtocol is used for control packets that represent - // destination port unreachable. - fakeControlProtocol tcpip.TransportProtocolNumber = 2 - - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65536 - - dstAddrOffset = 0 - srcAddrOffset = 1 - protocolNumberOffset = 2 -) - -func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error { - if addr, err := s.GetMainNICAddress(nicID, proto); err != nil { - return fmt.Errorf("stack.GetMainNICAddress(%d, %d): %s", nicID, proto, err) - } else if addr != want { - return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, proto, addr, want) - } - return nil -} - -// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and -// received packets; the counts of all endpoints are aggregated in the protocol -// descriptor. -// -// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only -// use the first three: destination address, source address, and transport -// protocol. They're all one byte fields to simplify parsing. -type fakeNetworkEndpoint struct { - stack.AddressableEndpointState - - mu struct { - sync.RWMutex - - enabled bool - forwarding bool - } - - nic stack.NetworkInterface - proto *fakeNetworkProtocol - dispatcher stack.TransportDispatcher -} - -func (f *fakeNetworkEndpoint) Enable() tcpip.Error { - f.mu.Lock() - defer f.mu.Unlock() - f.mu.enabled = true - return nil -} - -func (f *fakeNetworkEndpoint) Enabled() bool { - f.mu.RLock() - defer f.mu.RUnlock() - return f.mu.enabled -} - -func (f *fakeNetworkEndpoint) Disable() { - f.mu.Lock() - defer f.mu.Unlock() - f.mu.enabled = false -} - -func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.nic.MTU() - uint32(f.MaxHeaderLength()) -} - -func (*fakeNetworkEndpoint) DefaultTTL() uint8 { - return 123 -} - -func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { - if _, _, ok := f.proto.Parse(pkt); !ok { - return - } - - // Increment the received packet count in the protocol descriptor. - netHdr := pkt.NetworkHeader().View() - - dst := tcpip.Address(netHdr[dstAddrOffset:][:1]) - addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint) - if addressEndpoint == nil { - return - } - addressEndpoint.DecRef() - - f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++ - - // Handle control packets. - if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) - if !ok { - return - } - // DeleteFront invalidates slices. Make a copy before trimming. - nb := append([]byte(nil), hdr...) - pkt.Data().DeleteFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportError( - tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), - tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), - fakeNetNumber, - tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), - // Nothing checks the error. - nil, /* transport error */ - pkt, - ) - return - } - - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) -} - -func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.nic.MaxHeaderLength() + fakeNetHeaderLen -} - -func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return f.proto.Number() -} - -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { - // Increment the sent packet count in the protocol descriptor. - f.proto.sendPacketCount[int(r.RemoteAddress()[0])%len(f.proto.sendPacketCount)]++ - - // Add the protocol's header to the packet and send it to the link - // endpoint. - hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) - pkt.NetworkProtocolNumber = fakeNetNumber - hdr[dstAddrOffset] = r.RemoteAddress()[0] - hdr[srcAddrOffset] = r.LocalAddress()[0] - hdr[protocolNumberOffset] = byte(params.Protocol) - - if r.Loop()&stack.PacketLoop != 0 { - f.HandlePacket(pkt.Clone()) - } - if r.Loop()&stack.PacketOut == 0 { - return nil - } - - return f.nic.WritePacket(r, fakeNetNumber, pkt) -} - -// WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { - panic("not implemented") -} - -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (f *fakeNetworkEndpoint) Close() { - f.AddressableEndpointState.Cleanup() -} - -// Stats implements NetworkEndpoint. -func (*fakeNetworkEndpoint) Stats() stack.NetworkEndpointStats { - return &fakeNetworkEndpointStats{} -} - -var _ stack.NetworkEndpointStats = (*fakeNetworkEndpointStats)(nil) - -type fakeNetworkEndpointStats struct{} - -// IsNetworkEndpointStats implements stack.NetworkEndpointStats. -func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {} - -// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the -// number of packets sent and received via endpoints of this protocol. The index -// where packets are added is given by the packet's destination address MOD 10. -type fakeNetworkProtocol struct { - packetCount [10]int - sendPacketCount [10]int - defaultTTL uint8 -} - -func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { - return fakeNetNumber -} - -func (*fakeNetworkProtocol) MinimumPacketSize() int { - return fakeNetHeaderLen -} - -func (*fakeNetworkProtocol) DefaultPrefixLen() int { - return fakeDefaultPrefixLen -} - -func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { - return f.packetCount[int(intfAddr)%len(f.packetCount)] -} - -func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) -} - -func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { - e := &fakeNetworkEndpoint{ - nic: nic, - proto: f, - dispatcher: dispatcher, - } - e.AddressableEndpointState.Init(e) - return e -} - -func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { - switch v := option.(type) { - case *tcpip.DefaultTTLOption: - f.defaultTTL = uint8(*v) - return nil - default: - return &tcpip.ErrUnknownProtocolOption{} - } -} - -func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { - switch v := option.(type) { - case *tcpip.DefaultTTLOption: - *v = tcpip.DefaultTTLOption(f.defaultTTL) - return nil - default: - return &tcpip.ErrUnknownProtocolOption{} - } -} - -// Close implements NetworkProtocol.Close. -func (*fakeNetworkProtocol) Close() {} - -// Wait implements NetworkProtocol.Wait. -func (*fakeNetworkProtocol) Wait() {} - -// Parse implements NetworkProtocol.Parse. -func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen) - if !ok { - return 0, false, false - } - pkt.NetworkProtocolNumber = fakeNetNumber - return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true -} - -// Forwarding implements stack.ForwardingNetworkEndpoint. -func (f *fakeNetworkEndpoint) Forwarding() bool { - f.mu.RLock() - defer f.mu.RUnlock() - return f.mu.forwarding -} - -// SetForwarding implements stack.ForwardingNetworkEndpoint. -func (f *fakeNetworkEndpoint) SetForwarding(v bool) { - f.mu.Lock() - defer f.mu.Unlock() - f.mu.forwarding = v -} - -func fakeNetFactory(*stack.Stack) stack.NetworkProtocol { - return &fakeNetworkProtocol{} -} - -// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify -// that LinkEndpoint.Attach was called. -type linkEPWithMockedAttach struct { - stack.LinkEndpoint - attached bool -} - -// Attach implements stack.LinkEndpoint.Attach. -func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) { - l.LinkEndpoint.Attach(d) - l.attached = d != nil -} - -func (l *linkEPWithMockedAttach) isAttached() bool { - return l.attached -} - -// Checks to see if list contains an address. -func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool { - for _, i := range list { - if i == item { - return true - } - } - - return false -} - -func TestNetworkReceive(t *testing.T) { - // Create a stack with the fake network protocol, one nic, and two - // addresses attached to it: 1 & 2. - ep := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Make sure packet with wrong address is not delivered. - buf[dstAddrOffset] = 3 - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } - if fakeNet.packetCount[2] != 0 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) - } - - // Make sure packet is delivered to first endpoint. - buf[dstAddrOffset] = 1 - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 0 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) - } - - // Make sure packet is delivered to second endpoint. - buf[dstAddrOffset] = 2 - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } - - // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } - - // Make sure packet that is too small is dropped. - buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - if fakeNet.packetCount[2] != 1 { - t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) - } -} - -func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error { - r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - return err - } - defer r.Release() - return send(r, payload) -} - -func send(r *stack.Route, payload buffer.View) tcpip.Error { - return r.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: payload.ToVectorisedView(), - })) -} - -func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { - t.Helper() - ep.Drain() - if err := sendTo(s, addr, payload); err != nil { - t.Error("sendTo failed:", err) - } - if got, want := ep.Drain(), 1; got != want { - t.Errorf("sendTo packet count: got = %d, want %d", got, want) - } -} - -func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) { - t.Helper() - ep.Drain() - if err := send(r, payload); err != nil { - t.Error("send failed:", err) - } - if got, want := ep.Drain(), 1; got != want { - t.Errorf("send packet count: got = %d, want %d", got, want) - } -} - -func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) { - t.Helper() - if gotErr := send(r, payload); gotErr != wantErr { - t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) - } -} - -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) { - t.Helper() - if gotErr := sendTo(s, addr, payload); gotErr != wantErr { - t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) - } -} - -func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { - t.Helper() - // testRecvInternal injects one packet, and we expect to receive it. - want := fakeNet.PacketCount(localAddrByte) + 1 - testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) -} - -func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { - t.Helper() - // testRecvInternal injects one packet, and we do NOT expect to receive it. - want := fakeNet.PacketCount(localAddrByte) - testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) -} - -func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { - t.Helper() - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if got := fakeNet.PacketCount(localAddrByte); got != want { - t.Errorf("receive packet count: got = %d, want %d", got, want) - } -} - -func TestNetworkSend(t *testing.T) { - // Create a stack with the fake network protocol, one nic, and one - // address: 1. The route table sends all packets through the only - // existing nic. - ep := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("NewNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Make sure that the link-layer endpoint received the outbound packet. - testSendTo(t, s, "\x03", ep, nil) -} - -func TestNetworkSendMultiRoute(t *testing.T) { - // Create a stack with the fake network protocol, two nics, and two - // addresses per nic, the first nic has odd address, the second one has - // even addresses. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - }) - } - - // Send a packet to an odd destination. - testSendTo(t, s, "\x05", ep1, nil) - - // Send a packet to an even destination. - testSendTo(t, s, "\x06", ep2, nil) -} - -func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { - r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - - defer r.Release() - - if r.LocalAddress() != expectedSrcAddr { - t.Fatalf("got Route.LocalAddress() = %s, want = %s", expectedSrcAddr, r.LocalAddress()) - } - - if r.RemoteAddress() != dstAddr { - t.Fatalf("got Route.RemoteAddress() = %s, want = %s", dstAddr, r.RemoteAddress()) - } -} - -func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { - _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, &tcpip.ErrNoRoute{}) - } -} - -// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to -// a NetworkDispatcher when the NIC is created. -func TestAttachToLinkEndpointImmediately(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - nicOpts stack.NICOptions - }{ - { - name: "Create enabled NIC", - nicOpts: stack.NICOptions{Disabled: false}, - }, - { - name: "Create disabled NIC", - nicOpts: stack.NICOptions{Disabled: true}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - e := linkEPWithMockedAttach{ - LinkEndpoint: loopback.New(), - } - - if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err) - } - if !e.isAttached() { - t.Fatal("link endpoint not attached to a network dispatcher") - } - }) - } -} - -func TestDisableUnknownNIC(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - err := s.DisableNIC(1) - if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { - t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) - } -} - -func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - e := loopback.New() - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - checkNIC := func(enabled bool) { - t.Helper() - - allNICInfo := s.NICInfo() - nicInfo, ok := allNICInfo[nicID] - if !ok { - t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) - } else if nicInfo.Flags.Running != enabled { - t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled) - } - - if got := s.CheckNIC(nicID); got != enabled { - t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled) - } - } - - // NIC should initially report itself as disabled. - checkNIC(false) - - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - checkNIC(true) - - // If the NIC is not reporting a correct enabled status, we cannot trust the - // next check so end the test here. - if t.Failed() { - t.FailNow() - } - - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - checkNIC(false) -} - -func TestRemoveUnknownNIC(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - err := s.RemoveNIC(1) - if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { - t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) - } -} - -func TestRemoveNIC(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - e := linkEPWithMockedAttach{ - LinkEndpoint: loopback.New(), - } - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // NIC should be present in NICInfo and attached to a NetworkDispatcher. - allNICInfo := s.NICInfo() - if _, ok := allNICInfo[nicID]; !ok { - t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) - } - if !e.isAttached() { - t.Fatal("link endpoint not attached to a network dispatcher") - } - - // Removing a NIC should remove it from NICInfo and e should be detached from - // the NetworkDispatcher. - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) - } - if nicInfo, ok := s.NICInfo()[nicID]; ok { - t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) - } - if e.isAttached() { - t.Error("link endpoint for removed NIC still attached to a network dispatcher") - } -} - -func TestRouteWithDownNIC(t *testing.T) { - tests := []struct { - name string - downFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error - upFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error - }{ - { - name: "Disabled NIC", - downFn: (*stack.Stack).DisableNIC, - upFn: (*stack.Stack).EnableNIC, - }, - - // Once a NIC is removed, it cannot be brought up. - { - name: "Removed NIC", - downFn: (*stack.Stack).RemoveNIC, - }, - } - - const unspecifiedNIC = 0 - const nicID1 = 1 - const nicID2 = 2 - const addr1 = tcpip.Address("\x01") - const addr2 = tcpip.Address("\x02") - const nic1Dst = tcpip.Address("\x05") - const nic2Dst = tcpip.Address("\x06") - - setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep1 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) - } - - ep2 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID2, ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, - {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, - }) - } - - return s, ep1, ep2 - } - - // Tests that routes through a down NIC are not used when looking up a route - // for a destination. - t.Run("Find", func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, _, _ := setup(t) - - // Test routes to odd address. - testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) - testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) - testRoute(t, s, nicID1, addr1, "\x05", addr1) - - // Test routes to even address. - testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) - testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) - testRoute(t, s, nicID2, addr2, "\x06", addr2) - - // Bringing NIC1 down should result in no routes to odd addresses. Routes to - // even addresses should continue to be available as NIC2 is still up. - if err := test.downFn(s, nicID1); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID1, err) - } - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) - testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) - testRoute(t, s, nicID2, addr2, nic2Dst, addr2) - - // Bringing NIC2 down should result in no routes to even addresses. No - // route should be available to any address as routes to odd addresses - // were made unavailable by bringing NIC1 down above. - if err := test.downFn(s, nicID2); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID2, err) - } - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) - - if upFn := test.upFn; upFn != nil { - // Bringing NIC1 up should make routes to odd addresses available - // again. Routes to even addresses should continue to be unavailable - // as NIC2 is still down. - if err := upFn(s, nicID1); err != nil { - t.Fatalf("test.upFn(_, %d): %s", nicID1, err) - } - testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) - testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) - testRoute(t, s, nicID1, addr1, nic1Dst, addr1) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) - } - }) - } - }) - - // Tests that writing a packet using a Route through a down NIC fails. - t.Run("WritePacket", func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, ep1, ep2 := setup(t) - - r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) - } - defer r1.Release() - - r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) - } - defer r2.Release() - - // If we failed to get routes r1 or r2, we cannot proceed with the test. - if t.Failed() { - t.FailNow() - } - - buf := buffer.View([]byte{1}) - testSend(t, r1, ep1, buf) - testSend(t, r2, ep2, buf) - - // Writes with Routes that use NIC1 after being brought down should fail. - if err := test.downFn(s, nicID1); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID1, err) - } - testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) - testSend(t, r2, ep2, buf) - - // Writes with Routes that use NIC2 after being brought down should fail. - if err := test.downFn(s, nicID2); err != nil { - t.Fatalf("test.downFn(_, %d): %s", nicID2, err) - } - testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) - testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) - - if upFn := test.upFn; upFn != nil { - // Writes with Routes that use NIC1 after being brought up should - // succeed. - // - // TODO(gvisor.dev/issue/1491): Should we instead completely - // invalidate all Routes that were bound to a NIC that was brought - // down at some point? - if err := upFn(s, nicID1); err != nil { - t.Fatalf("test.upFn(_, %d): %s", nicID1, err) - } - testSend(t, r1, ep1, buf) - testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) - } - }) - } - }) -} - -func TestRoutes(t *testing.T) { - // Create a stack with the fake network protocol, two nics, and two - // addresses per nic, the first nic has odd address, the second one has - // even addresses. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - }) - } - - // Test routes to odd address. - testRoute(t, s, 0, "", "\x05", "\x01") - testRoute(t, s, 0, "\x01", "\x05", "\x01") - testRoute(t, s, 1, "\x01", "\x05", "\x01") - testRoute(t, s, 0, "\x03", "\x05", "\x03") - testRoute(t, s, 1, "\x03", "\x05", "\x03") - - // Test routes to even address. - testRoute(t, s, 0, "", "\x06", "\x02") - testRoute(t, s, 0, "\x02", "\x06", "\x02") - testRoute(t, s, 2, "\x02", "\x06", "\x02") - testRoute(t, s, 0, "\x04", "\x06", "\x04") - testRoute(t, s, 2, "\x04", "\x06", "\x04") - - // Try to send to odd numbered address from even numbered ones, then - // vice-versa. - testNoRoute(t, s, 0, "\x02", "\x05") - testNoRoute(t, s, 2, "\x02", "\x05") - testNoRoute(t, s, 0, "\x04", "\x05") - testNoRoute(t, s, 2, "\x04", "\x05") - - testNoRoute(t, s, 0, "\x01", "\x06") - testNoRoute(t, s, 1, "\x01", "\x06") - testNoRoute(t, s, 0, "\x03", "\x06") - testNoRoute(t, s, 1, "\x03", "\x06") -} - -func TestAddressRemoval(t *testing.T) { - const localAddrByte byte = 0x01 - localAddr := tcpip.Address([]byte{localAddrByte}) - remoteAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Send and receive packets, and verify they are received. - buf[dstAddrOffset] = localAddrByte - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // Remove the address, then check that send/receive doesn't work anymore. - if err := s.RemoveAddress(1, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) - - // Check that removing the same address fails. - err := s.RemoveAddress(1, localAddr) - if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) - } -} - -func TestAddressRemovalWithRouteHeld(t *testing.T) { - const localAddrByte byte = 0x01 - localAddr := tcpip.Address([]byte{localAddrByte}) - remoteAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - buf := buffer.NewView(30) - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - - // Send and receive packets, and verify they are received. - buf[dstAddrOffset] = localAddrByte - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSend(t, r, ep, nil) - testSendTo(t, s, remoteAddr, ep, nil) - - // Remove the address, then check that send/receive doesn't work anymore. - if err := s.RemoveAddress(1, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) - - // Check that removing the same address fails. - { - err := s.RemoveAddress(1, localAddr) - if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) - } - } -} - -func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) { - t.Helper() - info, ok := s.NICInfo()[nicID] - if !ok { - t.Fatalf("NICInfo() failed to find nicID=%d", nicID) - } - if len(addr) == 0 { - // No address given, verify that there is no address assigned to the NIC. - for _, a := range info.ProtocolAddresses { - if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{}) - } - } - return - } - // Address given, verify the address is assigned to the NIC and no other - // address is. - found := false - for _, a := range info.ProtocolAddresses { - if a.Protocol == fakeNetNumber { - if a.AddressWithPrefix.Address == addr { - found = true - } else { - t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr) - } - } - } - if !found { - t.Errorf("verify address: couldn't find %s on the NIC", addr) - } -} - -func TestEndpointExpiration(t *testing.T) { - const ( - localAddrByte byte = 0x01 - remoteAddr tcpip.Address = "\x03" - noAddr tcpip.Address = "" - nicID tcpip.NICID = 1 - ) - localAddr := tcpip.Address([]byte{localAddrByte}) - - for _, promiscuous := range []bool{true, false} { - for _, spoofing := range []bool{true, false} { - t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - buf := buffer.NewView(30) - buf[dstAddrOffset] = localAddrByte - - if promiscuous { - if err := s.SetPromiscuousMode(nicID, true); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - } - - if spoofing { - if err := s.SetSpoofing(nicID, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - } - - // 1. No Address yet, send should only work for spoofing, receive for - // promiscuous mode. - //----------------------- - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) - } - - // 2. Add Address, everything should work. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 3. Remove the address, send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) - } - - // 4. Add Address back, everything should work again. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 5. Take a reference to the endpoint by getting a route. Verify that - // we can still send/receive, including sending using the route. - //----------------------- - r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - testSend(t, r, ep, nil) - - // 6. Remove the address. Send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - testSend(t, r, ep, nil) - testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) - } - - // 7. Add Address back, everything should work again. - //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - testSend(t, r, ep, nil) - - // 8. Remove the route, sendTo/recv should still work. - //----------------------- - r.Release() - verifyAddress(t, s, nicID, localAddr) - testRecv(t, fakeNet, localAddrByte, ep, buf) - testSendTo(t, s, remoteAddr, ep, nil) - - // 9. Remove the address. Send should only work for spoofing, receive - // for promiscuous mode. - //----------------------- - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - verifyAddress(t, s, nicID, noAddr) - if promiscuous { - testRecv(t, fakeNet, localAddrByte, ep, buf) - } else { - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - } - if spoofing { - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) - } else { - testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) - } - }) - } - } -} - -func TestPromiscuousMode(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Write a packet, and check that it doesn't get delivered as we don't - // have a matching endpoint. - const localAddrByte byte = 0x01 - buf[dstAddrOffset] = localAddrByte - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - - // Set promiscuous mode, then check that packet is delivered. - if err := s.SetPromiscuousMode(1, true); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - testRecv(t, fakeNet, localAddrByte, ep, buf) - - // Check that we can't get a route as there is no local address. - _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, &tcpip.ErrNoRoute{}) - } - - // Set promiscuous mode to false, then check that packet can't be - // delivered anymore. - if err := s.SetPromiscuousMode(1, false); err != nil { - t.Fatal("SetPromiscuousMode failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) -} - -// TestExternalSendWithHandleLocal tests that the stack creates a non-local -// route when spoofing or promiscuous mode are enabled. -// -// This test makes sure that packets are transmitted from the stack. -func TestExternalSendWithHandleLocal(t *testing.T) { - const ( - unspecifiedNICID = 0 - nicID = 1 - - localAddr = tcpip.Address("\x01") - dstAddr = tcpip.Address("\x03") - ) - - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - - tests := []struct { - name string - configureStack func(*testing.T, *stack.Stack) - }{ - { - name: "Default", - configureStack: func(*testing.T, *stack.Stack) {}, - }, - { - name: "Spoofing", - configureStack: func(t *testing.T, s *stack.Stack) { - if err := s.SetSpoofing(nicID, true); err != nil { - t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err) - } - }, - }, - { - name: "Promiscuous", - configureStack: func(t *testing.T, s *stack.Stack) { - if err := s.SetPromiscuousMode(nicID, true); err != nil { - t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err) - } - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, handleLocal := range []bool{true, false} { - t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - HandleLocal: handleLocal, - }) - - ep := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) - - test.configureStack(t, s) - - r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err) - } - defer r.Release() - - if r.LocalAddress() != localAddr { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), localAddr) - } - if r.RemoteAddress() != dstAddr { - t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), dstAddr) - } - - if n := ep.Drain(); n != 0 { - t.Fatalf("got ep.Drain() = %d, want = 0", n) - } - if err := r.WritePacket(stack.NetworkHeaderParams{ - Protocol: fakeTransNumber, - TTL: 123, - TOS: stack.DefaultTOS, - }, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: buffer.NewView(10).ToVectorisedView(), - })); err != nil { - t.Fatalf("r.WritePacket(nil, _, _): %s", err) - } - if n := ep.Drain(); n != 1 { - t.Fatalf("got ep.Drain() = %d, want = 1", n) - } - }) - } - }) - } -} - -func TestSpoofingWithAddress(t *testing.T) { - localAddr := tcpip.Address("\x01") - nonExistentLocalAddr := tcpip.Address("\x02") - dstAddr := tcpip.Address("\x03") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // With address spoofing disabled, FindRoute does not permit an address - // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err == nil { - t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) - } - - // With address spoofing enabled, FindRoute permits any address to be used - // as the source. - if err := s.SetSpoofing(1, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress() != nonExistentLocalAddr { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nonExistentLocalAddr) - } - if r.RemoteAddress() != dstAddr { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), dstAddr) - } - // Sending a packet works. - testSendTo(t, s, dstAddr, ep, nil) - testSend(t, r, ep, nil) - - // FindRoute should also work with a local address that exists on the NIC. - r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress() != localAddr { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nonExistentLocalAddr) - } - if r.RemoteAddress() != dstAddr { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), dstAddr) - } - // Sending a packet using the route works. - testSend(t, r, ep, nil) -} - -func TestSpoofingNoAddress(t *testing.T) { - nonExistentLocalAddr := tcpip.Address("\x01") - dstAddr := tcpip.Address("\x02") - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // With address spoofing disabled, FindRoute does not permit an address - // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err == nil { - t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) - } - // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{}) - - // With address spoofing enabled, FindRoute permits any address to be used - // as the source. - if err := s.SetSpoofing(1, true); err != nil { - t.Fatal("SetSpoofing failed:", err) - } - r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatal("FindRoute failed:", err) - } - if r.LocalAddress() != nonExistentLocalAddr { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nonExistentLocalAddr) - } - if r.RemoteAddress() != dstAddr { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), dstAddr) - } - // Sending a packet works. - // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, ep, nil) -} - -func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - s.SetRouteTable([]tcpip.Route{}) - - // If there is no endpoint, it won't work. - { - _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { - t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) - } - } - - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} - if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) - } - r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) - } - if r.LocalAddress() != header.IPv4Any { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), header.IPv4Any) - } - - if r.RemoteAddress() != header.IPv4Broadcast { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), header.IPv4Broadcast) - } - - // If the NIC doesn't exist, it won't work. - { - _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { - t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) - } - } -} - -func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any} - // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24} - nic1Gateway := testutil.MustParse4("192.168.1.1") - // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24} - nic2Gateway := testutil.MustParse4("10.10.10.1") - - // Create a new stack with two NICs. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - if err := s.CreateNIC(2, ep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} - if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) - } - - nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} - if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) - } - - // Set the initial route table. - rt := []tcpip.Route{ - {Destination: nic1Addr.Subnet(), NIC: 1}, - {Destination: nic2Addr.Subnet(), NIC: 2}, - {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2}, - {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1}, - } - s.SetRouteTable(rt) - - // When an interface is given, the route for a broadcast goes through it. - r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) - } - if r.LocalAddress() != nic1Addr.Address { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nic1Addr.Address) - } - - if r.RemoteAddress() != header.IPv4Broadcast { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), header.IPv4Broadcast) - } - - // When an interface is not given, it consults the route table. - // 1. Case: Using the default route. - r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) - } - if r.LocalAddress() != nic2Addr.Address { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nic2Addr.Address) - } - - if r.RemoteAddress() != header.IPv4Broadcast { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), header.IPv4Broadcast) - } - - // 2. Case: Having an explicit route for broadcast will select that one. - rt = append( - []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, - }, - rt..., - ) - s.SetRouteTable(rt) - r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) - } - if r.LocalAddress() != nic1Addr.Address { - t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nic1Addr.Address) - } - - if r.RemoteAddress() != header.IPv4Broadcast { - t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), header.IPv4Broadcast) - } -} - -func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { - for _, tc := range []struct { - name string - routeNeeded bool - address tcpip.Address - }{ - // IPv4 multicast address range: 224.0.0.0 - 239.255.255.255 - // <=> 0xe0.0x00.0x00.0x00 - 0xef.0xff.0xff.0xff - {"IPv4 Multicast 1", false, "\xe0\x00\x00\x00"}, - {"IPv4 Multicast 2", false, "\xef\xff\xff\xff"}, - {"IPv4 Unicast 1", true, "\xdf\xff\xff\xff"}, - {"IPv4 Unicast 2", true, "\xf0\x00\x00\x00"}, - {"IPv4 Unicast 3", true, "\x00\x00\x00\x00"}, - - // IPv6 multicast address is 0xff[8] + flags[4] + scope[4] + groupId[112] - {"IPv6 Multicast 1", false, "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Multicast 2", false, "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Multicast 3", false, "\xff\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - - // IPv6 link-local address starts with fe80::/10. - {"IPv6 Unicast Link-Local 1", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Link-Local 2", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, - {"IPv6 Unicast Link-Local 3", false, "\xfe\x80\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff"}, - {"IPv6 Unicast Link-Local 4", false, "\xfe\xbf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Link-Local 5", false, "\xfe\xbf\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - - // IPv6 addresses that are neither multicast nor link-local. - {"IPv6 Unicast Not Link-Local 1", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 2", true, "\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"}, - {"IPv6 Unicast Not Link-local 3", true, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 4", true, "\xfe\xc0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 5", true, "\xfe\xdf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 6", true, "\xfd\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, - } { - t.Run(tc.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - s.SetRouteTable([]tcpip.Route{}) - - var anyAddr tcpip.Address - if len(tc.address) == header.IPv4AddressSize { - anyAddr = header.IPv4Any - } else { - anyAddr = header.IPv6Any - } - - var want tcpip.Error = &tcpip.ErrNetworkUnreachable{} - if tc.routeNeeded { - want = &tcpip.ErrNoRoute{} - } - - // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want) - } - - if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err) - } - - if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { - // Route table is empty but we need a route, this should cause an error. - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, &tcpip.ErrNoRoute{}) - } - } else { - if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", anyAddr, tc.address, fakeNetNumber, err) - } - if r.LocalAddress() != anyAddr { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress(), anyAddr) - } - if r.RemoteAddress() != tc.address { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress(), tc.address) - } - } - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", anyAddr, tc.address, fakeNetNumber, err, want) - } - }) - } -} - -func TestNetworkOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{}, - }) - - opt := tcpip.DefaultTTLOption(5) - if err := s.SetNetworkProtocolOption(fakeNetNumber, &opt); err != nil { - t.Fatalf("s.SetNetworkProtocolOption(%d, &%T(%d)): %s", fakeNetNumber, opt, opt, err) - } - - var optGot tcpip.DefaultTTLOption - if err := s.NetworkProtocolOption(fakeNetNumber, &optGot); err != nil { - t.Fatalf("s.NetworkProtocolOption(%d, &%T): %s", fakeNetNumber, optGot, err) - } - - if opt != optGot { - t.Errorf("got optGot = %d, want = %d", optGot, opt) - } -} - -func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { - const nicID = 1 - - for _, addrLen := range []int{4, 16} { - t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) { - for canBe := 0; canBe < 3; canBe++ { - t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) { - for never := 0; never < 3; never++ { - t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - // Insert <canBe> primary and <never> never-primary addresses. - // Each one will add a network endpoint to the NIC. - primaryAddrAdded := make(map[tcpip.AddressWithPrefix]struct{}) - for i := 0; i < canBe+never; i++ { - var behavior stack.PrimaryEndpointBehavior - if i < canBe { - behavior = stack.CanBePrimaryEndpoint - } else { - behavior = stack.NeverPrimaryEndpoint - } - // Add an address and in case of a primary one include a - // prefixLen. - address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) - if behavior == stack.CanBePrimaryEndpoint { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: addrLen * 8, - }, - } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err) - } - // Remember the address/prefix. - primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} - } else { - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err) - } - } - } - // Check that GetMainNICAddress returns an address if at least - // one primary address was added. In that case make sure the - // address/prefixLen matches what we added. - gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber) - if err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } - if len(primaryAddrAdded) == 0 { - // No primary addresses present. - if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, wantAddr) - } - } else { - // At least one primary address was added, verify the returned - // address is in the list of primary addresses we added. - if _, ok := primaryAddrAdded[gotAddr]; !ok { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, primaryAddrAdded) - } - } - }) - } - }) - } - }) - } -} - -func TestGetMainNICAddressErrors(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - // Sanity check with a successful call. - if addr, err := s.GetMainNICAddress(nicID, ipv4.ProtocolNumber); err != nil { - t.Errorf("s.GetMainNICAddress(%d, %d): %s", nicID, ipv4.ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ipv4.ProtocolNumber, addr, want) - } - - const unknownNICID = nicID + 1 - switch addr, err := s.GetMainNICAddress(unknownNICID, ipv4.ProtocolNumber); err.(type) { - case *tcpip.ErrUnknownNICID: - default: - t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownNICID)", unknownNICID, ipv4.ProtocolNumber, addr, err) - } - - // ARP is not an addressable network endpoint. - switch addr, err := s.GetMainNICAddress(nicID, arp.ProtocolNumber); err.(type) { - case *tcpip.ErrNotSupported: - default: - t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrNotSupported)", nicID, arp.ProtocolNumber, addr, err) - } - - const unknownProtocolNumber = 1234 - switch addr, err := s.GetMainNICAddress(nicID, unknownProtocolNumber); err.(type) { - case *tcpip.ErrUnknownProtocol: - default: - t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownProtocol)", nicID, unknownProtocolNumber, addr, err) - } -} - -func TestGetMainNICAddressAddRemove(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - for _, tc := range []struct { - name string - address tcpip.Address - prefixLen int - }{ - {"IPv4", "\x01\x01\x01\x01", 24}, - {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116}, - } { - t.Run(tc.name, func(t *testing.T) { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tc.address, - PrefixLen: tc.prefixLen, - }, - } - if err := s.AddProtocolAddress(1, protocolAddress); err != nil { - t.Fatal("AddProtocolAddress failed:", err) - } - - // Check that we get the right initial address and prefix length. - if err := checkGetMainNICAddress(s, 1, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { - t.Fatal(err) - } - - if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { - t.Fatal("RemoveAddress failed:", err) - } - - // Check that we get no address after removal. - if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - }) - } -} - -// Simple network address generator. Good for 255 addresses. -type addressGenerator struct{ cnt byte } - -func (g *addressGenerator) next(addrLen int) tcpip.Address { - g.cnt++ - return tcpip.Address(bytes.Repeat([]byte{g.cnt}, addrLen)) -} - -func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) { - t.Helper() - - if len(gotAddresses) != len(expectedAddresses) { - t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses)) - } - - sort.Slice(gotAddresses, func(i, j int) bool { - return gotAddresses[i].AddressWithPrefix.Address < gotAddresses[j].AddressWithPrefix.Address - }) - sort.Slice(expectedAddresses, func(i, j int) bool { - return expectedAddresses[i].AddressWithPrefix.Address < expectedAddresses[j].AddressWithPrefix.Address - }) - - for i, gotAddr := range gotAddresses { - expectedAddr := expectedAddresses[i] - if gotAddr != expectedAddr { - t.Errorf("got address = %+v, wanted = %+v", gotAddr, expectedAddr) - } - } -} - -func TestAddAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) - for _, addrLen := range []int{4, 16} { - address := addrGen.next(addrLen) - if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { - t.Fatalf("AddAddress(address=%s) failed: %s", address, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, behavior := range behaviorRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - for _, behavior := range behaviorRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestCreateNICWithOptions(t *testing.T) { - type callArgsAndExpect struct { - nicID tcpip.NICID - opts stack.NICOptions - err tcpip.Error - } - - tests := []struct { - desc string - calls []callArgsAndExpect - }{ - { - desc: "DuplicateNICID", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{Name: "eth1"}, - err: nil, - }, - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{Name: "eth2"}, - err: &tcpip.ErrDuplicateNICID{}, - }, - }, - }, - { - desc: "DuplicateName", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{Name: "lo"}, - err: nil, - }, - { - nicID: tcpip.NICID(2), - opts: stack.NICOptions{Name: "lo"}, - err: &tcpip.ErrDuplicateNICID{}, - }, - }, - }, - { - desc: "Unnamed", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{}, - err: nil, - }, - { - nicID: tcpip.NICID(2), - opts: stack.NICOptions{}, - err: nil, - }, - }, - }, - { - desc: "UnnamedDuplicateNICID", - calls: []callArgsAndExpect{ - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{}, - err: nil, - }, - { - nicID: tcpip.NICID(1), - opts: stack.NICOptions{}, - err: &tcpip.ErrDuplicateNICID{}, - }, - }, - }, - } - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - s := stack.New(stack.Options{}) - ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") - for _, call := range test.calls { - if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { - t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) - } - } - }) - } -} - -func TestNICStats(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - nics := []struct { - addr tcpip.Address - txByteCount int - rxByteCount int - }{ - { - addr: "\x01", - txByteCount: 30, - rxByteCount: 10, - }, - { - addr: "\x02", - txByteCount: 50, - rxByteCount: 20, - }, - } - - var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int - for i, nic := range nics { - nicid := tcpip.NICID(i) - ep := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { - t.Fatal("AddAddress failed:", err) - } - - { - subnet, err := tcpip.NewSubnet(nic.addr, "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}}) - } - - nicStats := s.NICInfo()[nicid].Stats - - // Inbound packet. - rxBuffer := buffer.NewView(nic.rxByteCount) - ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: rxBuffer.ToVectorisedView(), - })) - if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } - if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) - } - rxPacketsTotal++ - rxBytesTotal += nic.rxByteCount - - // Outbound packet. - txBuffer := buffer.NewView(nic.txByteCount) - actualTxLength := nic.txByteCount + fakeNetHeaderLen - if err := sendTo(s, nic.addr, txBuffer); err != nil { - t.Fatal("sendTo failed: ", err) - } - want := ep.Drain() - if got := nicStats.Tx.Packets.Value(); got != uint64(want) { - t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) - } - if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } - txPacketsTotal += want - txBytesTotal += actualTxLength - } - - // Now verify that each NIC stats was correctly aggregated at the stack level. - if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want { - t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want) - } - if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want { - t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want) - } - if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) - } - if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } -} - -// TestNICContextPreservation tests that you can read out via stack.NICInfo the -// Context data you pass via NICContext.Context in stack.CreateNICWithOptions. -func TestNICContextPreservation(t *testing.T) { - var ctx *int - tests := []struct { - name string - opts stack.NICOptions - want stack.NICContext - }{ - { - "context_set", - stack.NICOptions{Context: ctx}, - ctx, - }, - { - "context_not_set", - stack.NICOptions{}, - nil, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{}) - id := tcpip.NICID(1) - ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") - if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { - t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) - } - nicinfos := s.NICInfo() - nicinfo, ok := nicinfos[id] - if !ok { - t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) - } - if got, want := nicinfo.Context == test.want, true; got != want { - t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) - } - }) - } -} - -// TestNICAutoGenLinkLocalAddr tests the auto-generation of IPv6 link-local -// addresses. -func TestNICAutoGenLinkLocalAddr(t *testing.T) { - const nicID = 1 - - var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte - n, err := rand.Read(secretKey[:]) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n) - } - - nicNameFunc := func(_ tcpip.NICID, name string) string { - return name - } - - tests := []struct { - name string - nicName string - autoGen bool - linkAddr tcpip.LinkAddress - iidOpts ipv6.OpaqueInterfaceIdentifierOptions - shouldGen bool - expectedAddr tcpip.Address - }{ - { - name: "Disabled", - nicName: "nic1", - autoGen: false, - linkAddr: linkAddr1, - shouldGen: false, - }, - { - name: "Disabled without OIID options", - nicName: "nic1", - autoGen: false, - linkAddr: linkAddr1, - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:], - }, - shouldGen: false, - }, - - // Tests for EUI64 based addresses. - { - name: "EUI64 Enabled", - autoGen: true, - linkAddr: linkAddr1, - shouldGen: true, - expectedAddr: header.LinkLocalAddr(linkAddr1), - }, - { - name: "EUI64 Empty MAC", - autoGen: true, - shouldGen: false, - }, - { - name: "EUI64 Invalid MAC", - autoGen: true, - linkAddr: "\x01\x02\x03", - shouldGen: false, - }, - { - name: "EUI64 Multicast MAC", - autoGen: true, - linkAddr: "\x01\x02\x03\x04\x05\x06", - shouldGen: false, - }, - { - name: "EUI64 Unspecified MAC", - autoGen: true, - linkAddr: "\x00\x00\x00\x00\x00\x00", - shouldGen: false, - }, - - // Tests for Opaque IID based addresses. - { - name: "OIID Enabled", - nicName: "nic1", - autoGen: true, - linkAddr: linkAddr1, - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]), - }, - // These are all cases where we would not have generated a - // link-local address if opaque IIDs were disabled. - { - name: "OIID Empty MAC and empty nicName", - autoGen: true, - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:1], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("", 0, secretKey[:1]), - }, - { - name: "OIID Invalid MAC", - nicName: "test", - autoGen: true, - linkAddr: "\x01\x02\x03", - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:2], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("test", 0, secretKey[:2]), - }, - { - name: "OIID Multicast MAC", - nicName: "test2", - autoGen: true, - linkAddr: "\x01\x02\x03\x04\x05\x06", - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - SecretKey: secretKey[:3], - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("test2", 0, secretKey[:3]), - }, - { - name: "OIID Unspecified MAC and nil SecretKey", - nicName: "test3", - autoGen: true, - linkAddr: "\x00\x00\x00\x00\x00\x00", - iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: nicNameFunc, - }, - shouldGen: true, - expectedAddr: header.LinkLocalAddrWithOpaqueIID("test3", 0, nil), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: test.autoGen, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: test.iidOpts, - })}, - } - - e := channel.New(0, 1280, test.linkAddr) - s := stack.New(opts) - nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) - } - - // A new disabled NIC should not have any address, even if auto generation - // was enabled. - allStackAddrs := s.AllAddresses() - allNICAddrs, ok := allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) - } - - // Enabling the NIC should attempt auto-generation of a link-local - // address. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - var expectedMainAddr tcpip.AddressWithPrefix - if test.shouldGen { - expectedMainAddr = tcpip.AddressWithPrefix{ - Address: test.expectedAddr, - PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, - } - - // Should have auto-generated an address and resolved immediately (DAD - // is disabled). - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, expectedMainAddr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } else { - // Should not have auto-generated an address. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address") - default: - } - } - - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, expectedMainAddr); err != nil { - t.Fatal(err) - } - - // Disabling the NIC should remove the auto-generated address. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - }) - } -} - -// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are -// not auto-generated for loopback NICs. -func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { - const nicID = 1 - const nicName = "nicName" - - tests := []struct { - name string - opaqueIIDOpts ipv6.OpaqueInterfaceIdentifierOptions - }{ - { - name: "IID From MAC", - opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{}, - }, - { - name: "Opaque IID", - opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: true, - OpaqueIIDOpts: test.opaqueIIDOpts, - })}, - } - - e := loopback.New() - s := stack.New(opts) - nicOpts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - }) - } -} - -// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 -// link-local addresses will only be assigned after the DAD process resolves. -func TestNICAutoGenAddrDoesDAD(t *testing.T) { - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - dadConfigs := stack.DefaultDADConfigurations() - clock := faketime.NewManualClock() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenLinkLocal: true, - NDPDisp: &ndpDisp, - DADConfigs: dadConfigs, - })}, - Clock: clock, - } - - e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - // Address should not be considered bound to the - // NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - linkLocalAddr := header.LinkLocalAddr(linkAddr1) - - // Wait for DAD to resolve. - clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - default: - // We should get a resolution event after 1s (default time to - // resolve as per default NDP configurations). Waiting for that - // resolution time + an extra 1s without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil { - t.Fatal(err) - } -} - -// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected -// when an address's kind gets "promoted" to permanent from permanentExpired. -func TestNewPEBOnPromotionToPermanent(t *testing.T) { - const nicID = 1 - - pebs := []stack.PrimaryEndpointBehavior{ - stack.NeverPrimaryEndpoint, - stack.CanBePrimaryEndpoint, - stack.FirstPrimaryEndpoint, - } - - for _, pi := range pebs { - for _, ps := range pebs { - t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - // Add a permanent address with initial - // PrimaryEndpointBehavior (peb), pi. If pi is - // NeverPrimaryEndpoint, the address should not - // be returned by a call to GetMainNICAddress; - // else, it should. - const address1 = tcpip.Address("\x01") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) - } - addr, err := s.GetMainNICAddress(nicID, fakeNetNumber) - if err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } - if pi == stack.NeverPrimaryEndpoint { - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want) - - } - } else if addr.Address != address1 { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatalf("NewSubnet failed: %v", err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // Take a route through the address so its ref - // count gets incremented and does not actually - // get deleted when RemoveAddress is called - // below. This is because we want to test that a - // new peb is respected when an address gets - // "promoted" to permanent from a - // permanentExpired kind. - const address2 = tcpip.Address("\x02") - r, err := s.FindRoute(nicID, address1, address2, fakeNetNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, address1, address2, fakeNetNumber, err) - } - defer r.Release() - if err := s.RemoveAddress(nicID, address1); err != nil { - t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address1, err) - } - - // - // At this point, the address should still be - // known by the NIC, but have its - // kind = permanentExpired. - // - - // Add some other address with peb set to - // FirstPrimaryEndpoint. - const address3 = tcpip.Address("\x03") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err) - - } - - // Add back the address we removed earlier and - // make sure the new peb was respected. - // (The address should just be promoted now). - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) - } - var primaryAddrs []tcpip.Address - for _, pa := range s.NICInfo()[nicID].ProtocolAddresses { - primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address) - } - var expectedList []tcpip.Address - switch ps { - case stack.FirstPrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x01", - "\x03", - } - case stack.CanBePrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x03", - "\x01", - } - case stack.NeverPrimaryEndpoint: - expectedList = []tcpip.Address{ - "\x03", - } - } - if !cmp.Equal(primaryAddrs, expectedList) { - t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList) - } - - // Once we remove the other address, if the new - // peb, ps, was NeverPrimaryEndpoint, no address - // should be returned by a call to - // GetMainNICAddress; else, our original address - // should be returned. - if err := s.RemoveAddress(nicID, address3); err != nil { - t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address3, err) - } - addr, err = s.GetMainNICAddress(nicID, fakeNetNumber) - if err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } - if ps == stack.NeverPrimaryEndpoint { - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want) - } - } else { - if addr.Address != address1 { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1) - } - } - }) - } - } -} - -func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { - const ( - nicID = 1 - lifetimeSeconds = 9999 - ) - - var ( - linkLocalAddr1 = testutil.MustParse6("fe80::1") - linkLocalAddr2 = testutil.MustParse6("fe80::2") - linkLocalMulticastAddr = testutil.MustParse6("ff02::1") - uniqueLocalAddr1 = testutil.MustParse6("fc00::1") - uniqueLocalAddr2 = testutil.MustParse6("fd00::2") - globalAddr1 = testutil.MustParse6("a000::1") - globalAddr2 = testutil.MustParse6("a000::2") - globalAddr3 = testutil.MustParse6("a000::3") - ipv4MappedIPv6Addr1 = testutil.MustParse6("::ffff:0.0.0.1") - ipv4MappedIPv6Addr2 = testutil.MustParse6("::ffff:0.0.0.2") - toredoAddr1 = testutil.MustParse6("2001::1") - toredoAddr2 = testutil.MustParse6("2001::2") - ipv6ToIPv4Addr1 = testutil.MustParse6("2002::1") - ipv6ToIPv4Addr2 = testutil.MustParse6("2002::2") - ) - - prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1) - - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempGlobalAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr1.Address).Address - tempGlobalAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr2.Address).Address - - // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. - tests := []struct { - name string - slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix - nicAddrs []tcpip.Address - slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix - remoteAddr tcpip.Address - expectedLocalAddr tcpip.Address - }{ - // Test Rule 1 of RFC 6724 section 5 (prefer same address). - { - name: "Same Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: globalAddr1, - expectedLocalAddr: globalAddr1, - }, - { - name: "Same Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, - remoteAddr: globalAddr1, - expectedLocalAddr: globalAddr1, - }, - { - name: "Same Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, - remoteAddr: linkLocalAddr1, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Same Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: linkLocalAddr1, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Same Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1}, - remoteAddr: uniqueLocalAddr1, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Same Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, - remoteAddr: uniqueLocalAddr1, - expectedLocalAddr: uniqueLocalAddr1, - }, - - // Test Rule 2 of RFC 6724 section 5 (prefer appropriate scope). - { - name: "Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: globalAddr2, - expectedLocalAddr: globalAddr1, - }, - { - name: "Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, - remoteAddr: globalAddr2, - expectedLocalAddr: globalAddr1, - }, - { - name: "Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, - remoteAddr: linkLocalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: linkLocalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local most preferred for link local multicast (last address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, - remoteAddr: linkLocalMulticastAddr, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local most preferred for link local multicast (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, - remoteAddr: linkLocalMulticastAddr, - expectedLocalAddr: linkLocalAddr1, - }, - - // Test Rule 6 of 6724 section 5 (prefer matching label). - { - name: "Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1}, - remoteAddr: uniqueLocalAddr2, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1}, - remoteAddr: uniqueLocalAddr2, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Toredo most preferred (first address)", - nicAddrs: []tcpip.Address{toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1}, - remoteAddr: toredoAddr2, - expectedLocalAddr: toredoAddr1, - }, - { - name: "Toredo most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1}, - remoteAddr: toredoAddr2, - expectedLocalAddr: toredoAddr1, - }, - { - name: "6To4 most preferred (first address)", - nicAddrs: []tcpip.Address{ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1}, - remoteAddr: ipv6ToIPv4Addr2, - expectedLocalAddr: ipv6ToIPv4Addr1, - }, - { - name: "6To4 most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, uniqueLocalAddr1, toredoAddr1, ipv6ToIPv4Addr1}, - remoteAddr: ipv6ToIPv4Addr2, - expectedLocalAddr: ipv6ToIPv4Addr1, - }, - { - name: "IPv4 mapped IPv6 most preferred (first address)", - nicAddrs: []tcpip.Address{ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1}, - remoteAddr: ipv4MappedIPv6Addr2, - expectedLocalAddr: ipv4MappedIPv6Addr1, - }, - { - name: "IPv4 mapped IPv6 most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1, ipv4MappedIPv6Addr1}, - remoteAddr: ipv4MappedIPv6Addr2, - expectedLocalAddr: ipv4MappedIPv6Addr1, - }, - - // Test Rule 7 of RFC 6724 section 5 (prefer temporary addresses). - { - name: "Temp Global most preferred (last address)", - slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - remoteAddr: globalAddr2, - expectedLocalAddr: tempGlobalAddr1, - }, - { - name: "Temp Global most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - slaacPrefixForTempAddrAfterNICAddrAdd: prefix1, - remoteAddr: globalAddr2, - expectedLocalAddr: tempGlobalAddr1, - }, - - // Test Rule 8 of RFC 6724 section 5 (use longest matching prefix). - { - name: "Longest prefix matched most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr2, globalAddr1}, - remoteAddr: globalAddr3, - expectedLocalAddr: globalAddr2, - }, - { - name: "Longest prefix matched most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, globalAddr2}, - remoteAddr: globalAddr3, - expectedLocalAddr: globalAddr2, - }, - - // Test returning the endpoint that is closest to the front when - // candidate addresses are "equal" from the perspective of RFC 6724 - // section 5. - { - name: "Unique Local for Global", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2}, - remoteAddr: globalAddr2, - expectedLocalAddr: uniqueLocalAddr1, - }, - { - name: "Link Local for Global", - nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - remoteAddr: globalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Link Local for Unique Local", - nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - remoteAddr: uniqueLocalAddr2, - expectedLocalAddr: linkLocalAddr1, - }, - { - name: "Temp Global for Global", - slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, - slaacPrefixForTempAddrAfterNICAddrAdd: prefix2, - remoteAddr: globalAddr1, - expectedLocalAddr: tempGlobalAddr2, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDispatcher{}, - })}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) { - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) - } - - for _, a := range test.nicAddrs { - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { - t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) - } - } - - if test.slaacPrefixForTempAddrAfterNICAddrAdd != (tcpip.AddressWithPrefix{}) { - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrAfterNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) - } - - if t.Failed() { - t.FailNow() - } - - netEP, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } - - addressableEndpoint, ok := netEP.(stack.AddressableEndpoint) - if !ok { - t.Fatal("network endpoint is not addressable") - } - - addressEP := addressableEndpoint.AcquireOutgoingPrimaryAddress(test.remoteAddr, false /* allowExpired */) - if addressEP == nil { - t.Fatal("expected a non-nil address endpoint") - } - defer addressEP.DecRef() - - if got := addressEP.AddressWithPrefix().Address; got != test.expectedLocalAddr { - t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr) - } - }) - } -} - -func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { - const nicID = 1 - broadcastAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 32, - }, - } - - e := loopback.New() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - { - allStackAddrs := s.AllAddresses() - if allNICAddrs, ok := allStackAddrs[nicID]; !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } else if containsAddr(allNICAddrs, broadcastAddr) { - t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) - } - } - - // Enabling the NIC should add the IPv4 broadcast address. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - { - allStackAddrs := s.AllAddresses() - if allNICAddrs, ok := allStackAddrs[nicID]; !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } else if !containsAddr(allNICAddrs, broadcastAddr) { - t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr) - } - } - - // Disabling the NIC should remove the IPv4 broadcast address. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - - { - allStackAddrs := s.AllAddresses() - if allNICAddrs, ok := allStackAddrs[nicID]; !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } else if containsAddr(allNICAddrs, broadcastAddr) { - t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) - } - } -} - -// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6 -// address after leaving its solicited node multicast address does not result in -// an error. -func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - }) - e := channel.New(10, 1280, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) - } - - // The NIC should have joined addr1's solicited node multicast address. - snmc := header.SolicitedNodeAddr(addr1) - in, err := s.IsInGroup(nicID, snmc) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) - } - if !in { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc) - } - - if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil { - t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err) - } - in, err = s.IsInGroup(nicID, snmc) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) - } - if in { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc) - } - - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) - } -} - -func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - addr tcpip.Address - }{ - { - name: "IPv6 All-Nodes", - proto: header.IPv6ProtocolNumber, - addr: header.IPv6AllNodesMulticastAddress, - }, - { - name: "IPv4 All-Systems", - proto: header.IPv4ProtocolNumber, - addr: header.IPv4AllSystems, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := loopback.New() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - }) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - // Should not be in the multicast group yet because the NIC has not been - // enabled yet. - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) - } - - // The all-nodes multicast group should be joined when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if !isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) - } - - // The multicast group should be left when the NIC is disabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) - } - - // The all-nodes multicast group should be joined when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if !isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) - } - - // Leaving the group before disabling the NIC should not cause an error. - if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil { - t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err) - } - - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - - if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) - } else if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) - } - }) - } -} - -// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC -// was disabled have DAD performed on them when the NIC is enabled. -func TestDoDADWhenNICEnabled(t *testing.T) { - const dadTransmits = 1 - const retransmitTimer = time.Second - const nicID = 1 - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - clock := faketime.NewManualClock() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - } - - e := channel.New(dadTransmits, 1280, linkAddr1) - s := stack.New(opts) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) - } - - addr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: llAddr1, - PrefixLen: 128, - }, - } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) - } - - // Address should be in the list of all addresses. - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - - // Address should be tentative so it should not be a main address. - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Enabling the NIC should start DAD for the address. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - - // Address should not be considered bound to the NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - - // Wait for DAD to resolve. - clock.Advance(dadTransmits * retransmitTimer) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("timed out waiting for DAD resolution") - } - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { - t.Fatal(err) - } - - // Enabling the NIC again should be a no-op. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { - t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) - } - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { - t.Fatal(err) - } -} - -func TestStackReceiveBufferSizeOption(t *testing.T) { - const sMin = stack.MinBufferSize - testCases := []struct { - name string - rs tcpip.ReceiveBufferSizeOption - err tcpip.Error - }{ - // Invalid configurations. - {"min_below_zero", tcpip.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"min_zero", tcpip.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"default_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - {"default_above_max", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"max_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - - // Valid Configurations - {"in_ascending_order", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - s := stack.New(stack.Options{}) - defer s.Close() - if err := s.SetOption(tc.rs); err != tc.err { - t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err) - } - var rs tcpip.ReceiveBufferSizeOption - if tc.err == nil { - if err := s.Option(&rs); err != nil { - t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err) - } - if got, want := rs, tc.rs; got != want { - t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) - } - } - }) - } -} - -func TestStackSendBufferSizeOption(t *testing.T) { - const sMin = stack.MinBufferSize - testCases := []struct { - name string - ss tcpip.SendBufferSizeOption - err tcpip.Error - }{ - // Invalid configurations. - {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - - // Valid Configurations - {"in_ascending_order", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - s := stack.New(stack.Options{}) - defer s.Close() - err := s.SetOption(tc.ss) - if diff := cmp.Diff(tc.err, err); diff != "" { - t.Fatalf("unexpected error from s.SetOption(%+v), (-want, +got):\n%s", tc.ss, diff) - } - if tc.err == nil { - var ss tcpip.SendBufferSizeOption - if err := s.Option(&ss); err != nil { - t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) - } - if got, want := ss, tc.ss; got != want { - t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) - } - } - }) - } -} - -func TestOutgoingSubnetBroadcast(t *testing.T) { - const ( - unspecifiedNICID = 0 - nicID1 = 1 - ) - - defaultAddr := tcpip.AddressWithPrefix{ - Address: header.IPv4Any, - PrefixLen: 0, - } - defaultSubnet := defaultAddr.Subnet() - ipv4Addr := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 24, - } - ipv4Subnet := ipv4Addr.Subnet() - ipv4SubnetBcast := ipv4Subnet.Broadcast() - ipv4Gateway := testutil.MustParse4("192.168.1.1") - ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 31, - } - ipv4Subnet31 := ipv4AddrPrefix31.Subnet() - ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() - ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 32, - } - ipv4Subnet32 := ipv4AddrPrefix32.Subnet() - ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() - ipv6Addr := tcpip.AddressWithPrefix{ - Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - PrefixLen: 64, - } - ipv6Subnet := ipv6Addr.Subnet() - ipv6SubnetBcast := ipv6Subnet.Broadcast() - remNetAddr := tcpip.AddressWithPrefix{ - Address: "\x64\x0a\x7b\x18", - PrefixLen: 24, - } - remNetSubnet := remNetAddr.Subnet() - remNetSubnetBcast := remNetSubnet.Broadcast() - - tests := []struct { - name string - nicAddr tcpip.ProtocolAddress - routes []tcpip.Route - remoteAddr tcpip.Address - expectedLocalAddress tcpip.Address - expectedRemoteAddress tcpip.Address - expectedRemoteLinkAddress tcpip.LinkAddress - expectedNextHop tcpip.Address - expectedNetProto tcpip.NetworkProtocolNumber - expectedLoop stack.PacketLooping - }{ - // Broadcast to a locally attached subnet populates the broadcast MAC. - { - name: "IPv4 Broadcast to local subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet, - NIC: nicID1, - }, - }, - remoteAddr: ipv4SubnetBcast, - expectedLocalAddress: ipv4Addr.Address, - expectedRemoteAddress: ipv4SubnetBcast, - expectedRemoteLinkAddress: header.EthernetBroadcastAddress, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut | stack.PacketLoop, - }, - // Broadcast to a locally attached /31 subnet does not populate the - // broadcast MAC. - { - name: "IPv4 Broadcast to local /31 subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4AddrPrefix31, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet31, - NIC: nicID1, - }, - }, - remoteAddr: ipv4Subnet31Bcast, - expectedLocalAddress: ipv4AddrPrefix31.Address, - expectedRemoteAddress: ipv4Subnet31Bcast, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - // Broadcast to a locally attached /32 subnet does not populate the - // broadcast MAC. - { - name: "IPv4 Broadcast to local /32 subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4AddrPrefix32, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet32, - NIC: nicID1, - }, - }, - remoteAddr: ipv4Subnet32Bcast, - expectedLocalAddress: ipv4AddrPrefix32.Address, - expectedRemoteAddress: ipv4Subnet32Bcast, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - // IPv6 has no notion of a broadcast. - { - name: "IPv6 'Broadcast' to local subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: ipv6Addr, - }, - routes: []tcpip.Route{ - { - Destination: ipv6Subnet, - NIC: nicID1, - }, - }, - remoteAddr: ipv6SubnetBcast, - expectedLocalAddress: ipv6Addr.Address, - expectedRemoteAddress: ipv6SubnetBcast, - expectedNetProto: header.IPv6ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - // Broadcast to a remote subnet in the route table is send to the next-hop - // gateway. - { - name: "IPv4 Broadcast to remote subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: remNetSubnet, - Gateway: ipv4Gateway, - NIC: nicID1, - }, - }, - remoteAddr: remNetSubnetBcast, - expectedLocalAddress: ipv4Addr.Address, - expectedRemoteAddress: remNetSubnetBcast, - expectedNextHop: ipv4Gateway, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - // Broadcast to an unknown subnet follows the default route. Note that this - // is essentially just routing an unknown destination IP, because w/o any - // subnet prefix information a subnet broadcast address is just a normal IP. - { - name: "IPv4 Broadcast to unknown subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: defaultSubnet, - Gateway: ipv4Gateway, - NIC: nicID1, - }, - }, - remoteAddr: remNetSubnetBcast, - expectedLocalAddress: ipv4Addr.Address, - expectedRemoteAddress: remNetSubnetBcast, - expectedNextHop: ipv4Gateway, - expectedNetProto: header.IPv4ProtocolNumber, - expectedLoop: stack.PacketOut, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - }) - ep := channel.New(0, defaultMTU, "") - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID1, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) - } - - s.SetRouteTable(test.routes) - - var netProto tcpip.NetworkProtocolNumber - switch l := len(test.remoteAddr); l { - case header.IPv4AddressSize: - netProto = header.IPv4ProtocolNumber - case header.IPv6AddressSize: - netProto = header.IPv6ProtocolNumber - default: - t.Fatalf("got unexpected address length = %d bytes", l) - } - - r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err) - } - if r.LocalAddress() != test.expectedLocalAddress { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.expectedLocalAddress) - } - if r.RemoteAddress() != test.expectedRemoteAddress { - t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress(), test.expectedRemoteAddress) - } - if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress { - t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress) - } - if r.NextHop() != test.expectedNextHop { - t.Errorf("got r.NextHop() = %s, want = %s", r.NextHop(), test.expectedNextHop) - } - if r.NetProto() != test.expectedNetProto { - t.Errorf("got r.NetProto() = %d, want = %d", r.NetProto(), test.expectedNetProto) - } - if r.Loop() != test.expectedLoop { - t.Errorf("got r.Loop() = %x, want = %x", r.Loop(), test.expectedLoop) - } - }) - } -} - -func TestResolveWith(t *testing.T) { - const ( - unspecifiedNICID = 0 - nicID = 1 - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - }) - ep := channel.New(0, defaultMTU, "") - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - addr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address([]byte{192, 168, 1, 58}), - PrefixLen: 24, - }, - } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) - - remoteAddr := tcpip.Address([]byte{192, 168, 1, 59}) - r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err) - } - defer r.Release() - - // Should initially require resolution. - if !r.IsResolutionRequired() { - t.Fatal("got r.IsResolutionRequired() = false, want = true") - } - - // Manually resolving the route should no longer require resolution. - r.ResolveWith("\x01") - if r.IsResolutionRequired() { - t.Fatal("got r.IsResolutionRequired() = true, want = false") - } -} - -// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its -// associated address is removed should not cause a panic. -func TestRouteReleaseAfterAddrRemoval(t *testing.T) { - const ( - nicID = 1 - localAddr = tcpip.Address("\x01") - remoteAddr = tcpip.Address("\x02") - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - ep := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) - } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - r, err := s.FindRoute(nicID, localAddr, remoteAddr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, localAddr, remoteAddr, fakeNetNumber, err) - } - // Should not panic. - defer r.Release() - - // Check that removing the same address fails. - if err := s.RemoveAddress(nicID, localAddr); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, localAddr, err) - } -} - -func TestGetNetworkEndpoint(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - protoFactory stack.NetworkProtocolFactory - protoNum tcpip.NetworkProtocolNumber - }{ - { - name: "IPv4", - protoFactory: ipv4.NewProtocol, - protoNum: ipv4.ProtocolNumber, - }, - { - name: "IPv6", - protoFactory: ipv6.NewProtocol, - protoNum: ipv6.ProtocolNumber, - }, - } - - factories := make([]stack.NetworkProtocolFactory, 0, len(tests)) - for _, test := range tests { - factories = append(factories, test.protoFactory) - } - - s := stack.New(stack.Options{ - NetworkProtocols: factories, - }) - - if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep, err := s.GetNetworkEndpoint(nicID, test.protoNum) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, test.protoNum, err) - } - - if got := ep.NetworkProtocolNumber(); got != test.protoNum { - t.Fatalf("got ep.NetworkProtocolNumber() = %d, want = %d", got, test.protoNum) - } - }) - } -} - -func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - - if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: "\x01", - PrefixLen: 8, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err) - } - - // Check that we get the right initial address and prefix length. - if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { - t.Fatal(err) - } - - // Should still get the address when the NIC is diabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("DisableNIC(%d): %s", nicID, err) - } - if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { - t.Fatal(err) - } -} - -// TestAddRoute tests Stack.AddRoute -func TestAddRoute(t *testing.T) { - s := stack.New(stack.Options{}) - - subnet1, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - - subnet2, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - - expected := []tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet2, Gateway: "\x00", NIC: 1}, - } - - // Initialize the route table with one route. - s.SetRouteTable([]tcpip.Route{expected[0]}) - - // Add another route. - s.AddRoute(expected[1]) - - rt := s.GetRouteTable() - if got, want := len(rt), len(expected); got != want { - t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) - } - for i, route := range rt { - if got, want := route, expected[i]; got != want { - t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) - } - } -} - -// TestRemoveRoutes tests Stack.RemoveRoutes -func TestRemoveRoutes(t *testing.T) { - s := stack.New(stack.Options{}) - - addressToRemove := tcpip.Address("\x01") - subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01") - if err != nil { - t.Fatal(err) - } - - subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01") - if err != nil { - t.Fatal(err) - } - - subnet3, err := tcpip.NewSubnet("\x02", "\x02") - if err != nil { - t.Fatal(err) - } - - // Initialize the route table with three routes. - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - {Destination: subnet2, Gateway: "\x00", NIC: 1}, - {Destination: subnet3, Gateway: "\x00", NIC: 1}, - }) - - // Remove routes with the specific address. - s.RemoveRoutes(func(r tcpip.Route) bool { - return r.Destination.ID() == addressToRemove - }) - - expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}} - rt := s.GetRouteTable() - if got, want := len(rt), len(expected); got != want { - t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) - } - for i, route := range rt { - if got, want := route, expected[i]; got != want { - t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) - } - } -} - -func TestFindRouteWithForwarding(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - - nic1Addr = tcpip.Address("\x01") - nic2Addr = tcpip.Address("\x02") - remoteAddr = tcpip.Address("\x03") - ) - - type netCfg struct { - proto tcpip.NetworkProtocolNumber - factory stack.NetworkProtocolFactory - nic1Addr tcpip.Address - nic2Addr tcpip.Address - remoteAddr tcpip.Address - } - - fakeNetCfg := netCfg{ - proto: fakeNetNumber, - factory: fakeNetFactory, - nic1Addr: nic1Addr, - nic2Addr: nic2Addr, - remoteAddr: remoteAddr, - } - - globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) - globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) - - ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: llAddr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: globalIPv6Addr1, - } - ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: llAddr1, - remoteAddr: llAddr2, - } - ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - } - - tests := []struct { - name string - - netCfg netCfg - forwardingEnabled bool - - addrNIC tcpip.NICID - localAddr tcpip.Address - - findRouteErr tcpip.Error - dependentOnForwarding bool - }{ - { - name: "forwarding disabled and localAddr not on specified NIC but route from different NIC", - netCfg: fakeNetCfg, - forwardingEnabled: false, - addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr not on specified NIC but route from different NIC", - netCfg: fakeNetCfg, - forwardingEnabled: true, - addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and localAddr on specified NIC but route from different NIC", - netCfg: fakeNetCfg, - forwardingEnabled: false, - addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr on specified NIC but route from different NIC", - netCfg: fakeNetCfg, - forwardingEnabled: true, - addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: nil, - dependentOnForwarding: true, - }, - { - name: "forwarding disabled and localAddr on specified NIC and route from same NIC", - netCfg: fakeNetCfg, - forwardingEnabled: false, - addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr on specified NIC and route from same NIC", - netCfg: fakeNetCfg, - forwardingEnabled: true, - addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and localAddr not on specified NIC but route from same NIC", - netCfg: fakeNetCfg, - forwardingEnabled: false, - addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr not on specified NIC but route from same NIC", - netCfg: fakeNetCfg, - forwardingEnabled: true, - addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and localAddr on same NIC as route", - netCfg: fakeNetCfg, - forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr on same NIC as route", - netCfg: fakeNetCfg, - forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and localAddr on different NIC as route", - netCfg: fakeNetCfg, - forwardingEnabled: false, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and localAddr on different NIC as route", - netCfg: fakeNetCfg, - forwardingEnabled: true, - localAddr: fakeNetCfg.nic1Addr, - findRouteErr: nil, - dependentOnForwarding: true, - }, - { - name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: false, - addrNIC: nicID1, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: true, - addrNIC: nicID1, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and link-local local addr with route on different NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: false, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and link-local local addr with route on same NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: &tcpip.ErrNoRoute{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and global local addr with route on same NIC", - netCfg: ipv6LinkLocalNIC1WithGlobalRemote, - forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and link-local local addr with route on same NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalRemote, - forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and link-local local addr with route on same NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalRemote, - forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and global local addr with link-local remote on different NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalRemote, - forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: &tcpip.ErrNetworkUnreachable{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and global local addr with link-local remote on different NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalRemote, - forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: &tcpip.ErrNetworkUnreachable{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, - forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: &tcpip.ErrNetworkUnreachable{}, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, - forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: &tcpip.ErrNetworkUnreachable{}, - dependentOnForwarding: false, - }, - { - name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, - forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - { - name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", - netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, - forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, - findRouteErr: nil, - dependentOnForwarding: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory}, - }) - - ep1 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err) - } - - ep2 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID2, ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) - } - - if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) - } - - if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) - } - - if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) - - r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) - if err == nil { - defer r.Release() - } - if diff := cmp.Diff(test.findRouteErr, err); diff != "" { - t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff) - } - - if test.findRouteErr != nil { - return - } - - if r.LocalAddress() != test.localAddr { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddr) - } - if r.RemoteAddress() != test.netCfg.remoteAddr { - t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr) - } - - if t.Failed() { - t.FailNow() - } - - // Sending a packet should always go through NIC2 since we only install a - // route to test.netCfg.remoteAddr through NIC2. - data := buffer.View([]byte{1, 2, 3, 4}) - if err := send(r, data); err != nil { - t.Fatalf("send(_, _): %s", err) - } - if n := ep1.Drain(); n != 0 { - t.Errorf("got %d unexpected packets from ep1", n) - } - pkt, ok := ep2.Read() - if !ok { - t.Fatal("packet not sent through ep2") - } - if pkt.Route.LocalAddress != test.localAddr { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) - } - if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) - } - - if !test.forwardingEnabled || !test.dependentOnForwarding { - return - } - - // Disabling forwarding when the route is dependent on forwarding being - // enabled should make the route invalid. - if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, false); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", test.netCfg.proto, err) - } - { - err := send(r, data) - if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { - t.Fatalf("got send(_, _) = %s, want = %s", err, &tcpip.ErrInvalidEndpointState{}) - } - } - if n := ep1.Drain(); n != 0 { - t.Errorf("got %d unexpected packets from ep1", n) - } - if n := ep2.Drain(); n != 0 { - t.Errorf("got %d unexpected packets from ep2", n) - } - }) - } -} - -func TestWritePacketToRemote(t *testing.T) { - const nicID = 1 - const MTU = 1280 - e := channel.New(1, MTU, linkAddr1) - s := stack.New(stack.Options{}) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("CreateNIC(%d) = %s", nicID, err) - } - tests := []struct { - name string - protocol tcpip.NetworkProtocolNumber - payload []byte - }{ - { - name: "SuccessIPv4", - protocol: header.IPv4ProtocolNumber, - payload: []byte{1, 2, 3, 4}, - }, - { - name: "SuccessIPv6", - protocol: header.IPv6ProtocolNumber, - payload: []byte{5, 6, 7, 8}, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil { - t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err) - } - - pkt, ok := e.Read() - if got, want := ok, true; got != want { - t.Fatalf("e.Read() = %t, want %t", got, want) - } - if got, want := pkt.Proto, test.protocol; got != want { - t.Fatalf("pkt.Proto = %d, want %d", got, want) - } - if pkt.Route.RemoteLinkAddress != linkAddr2 { - t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2) - } - if diff := cmp.Diff(pkt.Pkt.Data().AsRange().ToOwnedView(), buffer.View(test.payload)); diff != "" { - t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) - } - }) - } - - t.Run("InvalidNICID", func(t *testing.T) { - err := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()) - if _, ok := err.(*tcpip.ErrUnknownDevice); !ok { - t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", err, &tcpip.ErrUnknownDevice{}) - } - pkt, ok := e.Read() - if got, want := ok, false; got != want { - t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want) - } - }) -} - -func TestClearNeighborCacheOnNICDisable(t *testing.T) { - const ( - nicID = 1 - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - ) - - var ( - ipv4Addr = testutil.MustParse4("1.2.3.4") - ipv6Addr = testutil.MustParse6("102:304:102:304:102:304:102:304") - ) - - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - Clock: clock, - }) - e := channel.New(0, 0, "") - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - addrs := []struct { - proto tcpip.NetworkProtocolNumber - addr tcpip.Address - }{ - { - proto: ipv4.ProtocolNumber, - addr: ipv4Addr, - }, - { - proto: ipv6.ProtocolNumber, - addr: ipv6Addr, - }, - } - for _, addr := range addrs { - if err := s.AddStaticNeighbor(nicID, addr.proto, addr.addr, linkAddr); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, addr.proto, addr.addr, linkAddr, err) - } - - if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) - } else if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAt: clock.Now()}}, - neighbors, - ); diff != "" { - t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff) - } - } - - // Disabling the NIC should clear the neighbor table. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - for _, addr := range addrs { - if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) - } else if len(neighbors) != 0 { - t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) - } - } - - // Enabling the NIC should have an empty neighbor table. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - for _, addr := range addrs { - if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { - t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) - } else if len(neighbors) != 0 { - t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) - } - } -} - -func TestGetLinkAddressErrors(t *testing.T) { - const ( - nicID = 1 - unknownNICID = nicID + 1 - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - { - err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil) - if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrUnknownNICID{}) - } - } - { - err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil) - if _, ok := err.(*tcpip.ErrNotSupported); !ok { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrNotSupported{}) - } - } -} - -func TestStaticGetLinkAddress(t *testing.T) { - const ( - nicID = 1 - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - }) - e := channel.New(0, 0, "") - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - addr tcpip.Address - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "IPv4", - proto: ipv4.ProtocolNumber, - addr: header.IPv4Broadcast, - expectedLinkAddr: header.EthernetBroadcastAddress, - }, - { - name: "IPv6", - proto: ipv6.ProtocolNumber, - addr: header.IPv6AllNodesMulticastAddress, - expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ch := make(chan stack.LinkResolutionResult, 1) - if err := s.GetLinkAddress(nicID, test.addr, "", test.proto, func(r stack.LinkResolutionResult) { - ch <- r - }); err != nil { - t.Fatalf("s.GetLinkAddress(%d, %s, '', %d, _): %s", nicID, test.addr, test.proto, err) - } - - if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Err: nil}, <-ch); diff != "" { - t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) - } - }) - } -} diff --git a/pkg/tcpip/stack/stack_unsafe_state_autogen.go b/pkg/tcpip/stack/stack_unsafe_state_autogen.go new file mode 100644 index 000000000..758ab3457 --- /dev/null +++ b/pkg/tcpip/stack/stack_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package stack diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go deleted file mode 100644 index 45b09110d..000000000 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ /dev/null @@ -1,446 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "io/ioutil" - "math" - "math/rand" - "strconv" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - - testSrcAddrV4 = "\x0a\x00\x00\x01" - testDstAddrV4 = "\x0a\x00\x00\x02" - - testDstPort = 1234 - testSrcPort = 4096 -) - -type testContext struct { - linkEps map[tcpip.NICID]*channel.Endpoint - s *stack.Stack - wq waiter.Queue -} - -// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. -func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - linkEps := make(map[tcpip.NICID]*channel.Endpoint) - for _, linkEpID := range linkEpIDs { - channelEp := channel.New(256, mtu, "") - if err := s.CreateNIC(linkEpID, channelEp); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - linkEps[linkEpID] = channelEp - - if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { - t.Fatalf("AddAddress IPv4 failed: %s", err) - } - - if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { - t.Fatalf("AddAddress IPv6 failed: %s", err) - } - } - - s.SetRouteTable([]tcpip.Route{ - {Destination: header.IPv4EmptySubnet, NIC: 1}, - {Destination: header.IPv6EmptySubnet, NIC: 1}, - }) - - return &testContext{ - s: s, - linkEps: linkEps, - } -} - -type headers struct { - srcPort uint16 - dstPort uint16 -} - -func newPayload() []byte { - b := make([]byte, 30+rand.Intn(100)) - for i := range b { - b[i] = byte(rand.Intn(256)) - } - return b -} - -func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { - buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TOS: 0x80, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: testSrcAddrV4, - DstAddr: testDstAddrV4, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - }) - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt) -} - -func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - TransportProtocol: udp.ProtocolNumber, - HopLimit: 65, - SrcAddr: testSrcAddrV6, - DstAddr: testDstAddrV6, - }) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - }) - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt) -} - -func TestTransportDemuxerRegister(t *testing.T) { - for _, test := range []struct { - name string - proto tcpip.NetworkProtocolNumber - want tcpip.Error - }{ - {"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}}, - {"success", ipv4.ProtocolNumber, nil}, - } { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - tEP, ok := ep.(stack.TransportEndpoint) - if !ok { - t.Fatalf("%T does not implement stack.TransportEndpoint", ep) - } - if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { - t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) - } - }) - } -} - -func TestTransportDemuxerRegisterMultiple(t *testing.T) { - type test struct { - flags ports.Flags - want tcpip.Error - } - for _, subtest := range []struct { - name string - tests []test - }{ - {"zeroFlags", []test{ - {ports.Flags{}, nil}, - {ports.Flags{}, &tcpip.ErrPortInUse{}}, - }}, - {"multibindFlags", []test{ - // Allow multiple registrations same TransportEndpointID with multibind flags. - {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, - {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, - // Disallow registration w/same ID for a non-multibindflag. - {ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}}, - }}, - } { - t.Run(subtest.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - var eps []tcpip.Endpoint - for idx, test := range subtest.tests { - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatal(err) - } - eps = append(eps, ep) - tEP, ok := ep.(stack.TransportEndpoint) - if !ok { - t.Fatalf("%T does not implement stack.TransportEndpoint", ep) - } - id := stack.TransportEndpointID{LocalPort: 1} - if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want { - t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want) - } - } - for _, ep := range eps { - ep.Close() - } - }) - } -} - -// TestBindToDeviceDistribution injects varied packets on input devices and checks that -// the distribution of packets received matches expectations. -func TestBindToDeviceDistribution(t *testing.T) { - type endpointSockopts struct { - reuse bool - bindToDevice tcpip.NICID - } - tcs := []struct { - name string - // endpoints will received the inject packets. - endpoints []endpointSockopts - // wantDistributions is the want ratio of packets received on each - // endpoint for each NIC on which packets are injected. - wantDistributions map[tcpip.NICID][]float64 - }{ - { - name: "BindPortReuse", - // 5 endpoints that all have reuse set. - endpoints: []endpointSockopts{ - {reuse: true, bindToDevice: 0}, - {reuse: true, bindToDevice: 0}, - {reuse: true, bindToDevice: 0}, - {reuse: true, bindToDevice: 0}, - {reuse: true, bindToDevice: 0}, - }, - wantDistributions: map[tcpip.NICID][]float64{ - // Injected packets on dev0 get distributed evenly. - 1: {0.2, 0.2, 0.2, 0.2, 0.2}, - }, - }, - { - name: "BindToDevice", - // 3 endpoints with various bindings. - endpoints: []endpointSockopts{ - {reuse: false, bindToDevice: 1}, - {reuse: false, bindToDevice: 2}, - {reuse: false, bindToDevice: 3}, - }, - wantDistributions: map[tcpip.NICID][]float64{ - // Injected packets on dev0 go only to the endpoint bound to dev0. - 1: {1, 0, 0}, - // Injected packets on dev1 go only to the endpoint bound to dev1. - 2: {0, 1, 0}, - // Injected packets on dev2 go only to the endpoint bound to dev2. - 3: {0, 0, 1}, - }, - }, - { - name: "ReuseAndBindToDevice", - // 6 endpoints with various bindings. - endpoints: []endpointSockopts{ - {reuse: true, bindToDevice: 1}, - {reuse: true, bindToDevice: 1}, - {reuse: true, bindToDevice: 2}, - {reuse: true, bindToDevice: 2}, - {reuse: true, bindToDevice: 2}, - {reuse: true, bindToDevice: 0}, - }, - wantDistributions: map[tcpip.NICID][]float64{ - // Injected packets on dev0 get distributed among endpoints bound to - // dev0. - 1: {0.5, 0.5, 0, 0, 0, 0}, - // Injected packets on dev1 get distributed among endpoints bound to - // dev1 or unbound. - 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, - // Injected packets on dev999 go only to the unbound. - 1000: {0, 0, 0, 0, 0, 1}, - }, - }, - } - protos := map[string]tcpip.NetworkProtocolNumber{ - "IPv4": ipv4.ProtocolNumber, - "IPv6": ipv6.ProtocolNumber, - } - - for _, test := range tcs { - for protoName, protoNum := range protos { - for device, wantDistribution := range test.wantDistributions { - t.Run(test.name+protoName+"-"+strconv.Itoa(int(device)), func(t *testing.T) { - // Create the NICs. - var devices []tcpip.NICID - for d := range test.wantDistributions { - devices = append(devices, d) - } - c := newDualTestContextMultiNIC(t, defaultMTU, devices) - - // Create endpoints and bind each to a NIC, sometimes reusing ports. - eps := make(map[tcpip.Endpoint]int) - pollChannel := make(chan tcpip.Endpoint) - for i, endpoint := range test.endpoints { - // Try to receive the data. - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - t.Cleanup(func() { - wq.EventUnregister(&we) - close(ch) - }) - - var err tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, protoNum, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - t.Cleanup(ep.Close) - eps[ep] = i - - go func(ep tcpip.Endpoint) { - for range ch { - pollChannel <- ep - } - }(ep) - - ep.SocketOptions().SetReusePort(endpoint.reuse) - if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { - t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) - } - - var dstAddr tcpip.Address - switch protoNum { - case ipv4.ProtocolNumber: - dstAddr = testDstAddrV4 - case ipv6.ProtocolNumber: - dstAddr = testDstAddrV6 - default: - t.Fatalf("unexpected protocol number: %d", protoNum) - } - if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { - t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) - } - } - - // Send packets across a range of ports, checking that packets from - // the same source port are always demultiplexed to the same - // destination endpoint. - npackets := 10_000 - nports := 1_000 - if got, want := len(test.endpoints), len(wantDistribution); got != want { - t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) - } - endpoints := make(map[uint16]tcpip.Endpoint) - stats := make(map[tcpip.Endpoint]int) - for i := 0; i < npackets; i++ { - // Send a packet. - port := uint16(i % nports) - payload := newPayload() - hdrs := &headers{ - srcPort: testSrcPort + port, - dstPort: testDstPort, - } - switch protoNum { - case ipv4.ProtocolNumber: - c.sendV4Packet(payload, hdrs, device) - case ipv6.ProtocolNumber: - c.sendV6Packet(payload, hdrs, device) - default: - t.Fatalf("unexpected protocol number: %d", protoNum) - } - - ep := <-pollChannel - if _, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != nil { - t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) - } - stats[ep]++ - if i < nports { - endpoints[uint16(i)] = ep - } else { - // Check that all packets from one client are handled by the same - // socket. - if want, got := endpoints[port], ep; want != got { - t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) - } - } - } - - // Check that a packet distribution is as expected. - for ep, i := range eps { - wantRatio := wantDistribution[i] - wantRecv := wantRatio * float64(npackets) - actualRecv := stats[ep] - actualRatio := float64(stats[ep]) / float64(npackets) - // The deviation is less than 10%. - if math.Abs(actualRatio-wantRatio) > 0.05 { - t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets) - } - } - }) - } - } - } -} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go deleted file mode 100644 index 839178809..000000000 --- a/pkg/tcpip/stack/transport_test.go +++ /dev/null @@ -1,555 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack_test - -import ( - "bytes" - "io" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeTransHeaderLen int = 3 -) - -// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts -// received packets; the counts of all endpoints are aggregated in the protocol -// descriptor. -// -// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't -// use it. -type fakeTransportEndpoint struct { - stack.TransportEndpointInfo - tcpip.DefaultSocketOptionsHandler - - proto *fakeTransportProtocol - peerAddr tcpip.Address - route *stack.Route - uniqueID uint64 - - // acceptQueue is non-nil iff bound. - acceptQueue []*fakeTransportEndpoint - - // ops is used to set and get socket options. - ops tcpip.SocketOptions -} - -func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { - return &f.TransportEndpointInfo -} - -func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { - return nil -} - -func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} - -func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { - return &f.ops -} - -func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint { - ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()} - ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - return ep -} - -func (f *fakeTransportEndpoint) Abort() { - f.Close() -} - -func (f *fakeTransportEndpoint) Close() { - // TODO(gvisor.dev/issue/5153): Consider retaining the route. - f.route.Release() -} - -func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { - return mask -} - -func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { - return tcpip.ReadResult{}, nil -} - -func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - if len(f.route.RemoteAddress()) == 0 { - return 0, &tcpip.ErrNoRoute{} - } - - v := make([]byte, p.Len()) - if _, err := io.ReadFull(p, v); err != nil { - return 0, &tcpip.ErrBadBuffer{} - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, - Data: buffer.View(v).ToVectorisedView(), - }) - _ = pkt.TransportHeader().Push(fakeTransHeaderLen) - if err := f.route.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { - return 0, err - } - - return int64(len(v)), nil -} - -// SetSockOpt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { - return &tcpip.ErrInvalidEndpointState{} -} - -// SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { - return &tcpip.ErrInvalidEndpointState{} -} - -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { - return -1, &tcpip.ErrUnknownProtocolOption{} -} - -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { - return &tcpip.ErrInvalidEndpointState{} -} - -// Disconnect implements tcpip.Endpoint.Disconnect. -func (*fakeTransportEndpoint) Disconnect() tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) tcpip.Error { - f.peerAddr = addr.Addr - - // Find the route. - r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) - if err != nil { - return &tcpip.ErrNoRoute{} - } - - // Try to register so that we can start receiving packets. - f.ID.RemoteAddress = addr.Addr - err = f.proto.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) - if err != nil { - r.Release() - return err - } - - f.route = r - - return nil -} - -func (f *fakeTransportEndpoint) UniqueID() uint64 { - return f.uniqueID -} - -func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) tcpip.Error { - return nil -} - -func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { - return nil -} - -func (*fakeTransportEndpoint) Reset() { -} - -func (*fakeTransportEndpoint) Listen(int) tcpip.Error { - return nil -} - -func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { - if len(f.acceptQueue) == 0 { - return nil, nil, nil - } - a := f.acceptQueue[0] - f.acceptQueue = f.acceptQueue[1:] - return a, nil, nil -} - -func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) tcpip.Error { - if err := f.proto.stack.RegisterTransportEndpoint( - []tcpip.NetworkProtocolNumber{fakeNetNumber}, - fakeTransNumber, - stack.TransportEndpointID{LocalAddress: a.Addr}, - f, - ports.Flags{}, - 0, /* bindtoDevice */ - ); err != nil { - return err - } - f.acceptQueue = []*fakeTransportEndpoint{} - return nil -} - -func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - // Increment the number of received packets. - f.proto.packetCount++ - if f.acceptQueue == nil { - return - } - - netHdr := pkt.NetworkHeader().View() - route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */) - if err != nil { - return - } - - ep := &fakeTransportEndpoint{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - ID: f.ID, - NetProto: f.NetProto, - }, - proto: f.proto, - peerAddr: route.RemoteAddress(), - route: route, - } - ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - f.acceptQueue = append(f.acceptQueue, ep) -} - -func (f *fakeTransportEndpoint) HandleError(stack.TransportError, *stack.PacketBuffer) { - // Increment the number of received control packets. - f.proto.controlCount++ -} - -func (*fakeTransportEndpoint) State() uint32 { - return 0 -} - -func (*fakeTransportEndpoint) ModerateRecvBuf(copied int) {} - -func (*fakeTransportEndpoint) Resume(*stack.Stack) {} - -func (*fakeTransportEndpoint) Wait() {} - -func (*fakeTransportEndpoint) LastError() tcpip.Error { - return nil -} - -type fakeTransportGoodOption bool - -type fakeTransportBadOption bool - -type fakeTransportInvalidValueOption int - -type fakeTransportProtocolOptions struct { - good bool -} - -// fakeTransportProtocol is a transport-layer protocol descriptor. It -// aggregates the number of packets received via endpoints of this protocol. -type fakeTransportProtocol struct { - stack *stack.Stack - - packetCount int - controlCount int - opts fakeTransportProtocolOptions -} - -func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { - return fakeTransNumber -} - -func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { - return newFakeTransportEndpoint(f, netProto, f.stack), nil -} - -func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { - return nil, &tcpip.ErrUnknownProtocol{} -} - -func (*fakeTransportProtocol) MinimumPacketSize() int { - return fakeTransHeaderLen -} - -func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err tcpip.Error) { - return 0, 0, nil -} - -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { - return stack.UnknownDestinationPacketHandled -} - -func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error { - switch v := option.(type) { - case *tcpip.TCPModerateReceiveBufferOption: - f.opts.good = bool(*v) - return nil - default: - return &tcpip.ErrUnknownProtocolOption{} - } -} - -func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error { - switch v := option.(type) { - case *tcpip.TCPModerateReceiveBufferOption: - *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good) - return nil - default: - return &tcpip.ErrUnknownProtocolOption{} - } -} - -// Abort implements TransportProtocol.Abort. -func (*fakeTransportProtocol) Abort() {} - -// Close implements tcpip.Endpoint.Close. -func (*fakeTransportProtocol) Close() {} - -// Wait implements TransportProtocol.Wait. -func (*fakeTransportProtocol) Wait() {} - -// Parse implements TransportProtocol.Parse. -func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { - _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) - return ok -} - -func fakeTransFactory(s *stack.Stack) stack.TransportProtocol { - return &fakeTransportProtocol{stack: s} -} - -func TestTransportReceive(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Create endpoint and connect to remote address. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - - // Create buffer that will hold the packet. - buf := buffer.NewView(30) - - // Make sure packet with wrong protocol is not delivered. - buf[0] = 1 - buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.packetCount != 0 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) - } - - // Make sure packet from the wrong source is not delivered. - buf[0] = 1 - buf[1] = 3 - buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.packetCount != 0 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) - } - - // Make sure packet is delivered. - buf[0] = 1 - buf[1] = 2 - buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.packetCount != 1 { - t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) - } -} - -func TestTransportControlReceive(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - // Create endpoint and connect to remote address. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - - // Create buffer that will hold the control packet. - buf := buffer.NewView(2*fakeNetHeaderLen + 30) - - // Outer packet contains the control protocol number. - buf[0] = 1 - buf[1] = 0xfe - buf[2] = uint8(fakeControlProtocol) - - // Make sure packet with wrong protocol is not delivered. - buf[fakeNetHeaderLen+0] = 0 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.controlCount != 0 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) - } - - // Make sure packet from the wrong source is not delivered. - buf[fakeNetHeaderLen+0] = 3 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.controlCount != 0 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) - } - - // Make sure packet is delivered. - buf[fakeNetHeaderLen+0] = 2 - buf[fakeNetHeaderLen+1] = 1 - buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if fakeTrans.controlCount != 1 { - t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) - } -} - -func TestTransportSend(t *testing.T) { - linkEP := channel.New(10, defaultMTU, "") - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - if err := s.CreateNIC(1, linkEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - // Create endpoint and bind it. - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { - t.Fatalf("Connect failed: %v", err) - } - - // Create buffer that will hold the payload. - b := make([]byte, 30) - var r bytes.Reader - r.Reset(b) - if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("write failed: %v", err) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - if fakeNet.sendPacketCount[2] != 1 { - t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1) - } -} - -func TestTransportOptions(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - - v := tcpip.TCPModerateReceiveBufferOption(true) - if err := s.SetTransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Errorf("s.SetTransportProtocolOption(fakeTrans, &%T(%t)): %s", v, v, err) - } - v = false - if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &%T): %s", v, err) - } - if !v { - t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") - } -} diff --git a/pkg/tcpip/stack/tuple_list.go b/pkg/tcpip/stack/tuple_list.go new file mode 100644 index 000000000..31d0feefa --- /dev/null +++ b/pkg/tcpip/stack/tuple_list.go @@ -0,0 +1,221 @@ +package stack + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type tupleElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (tupleElementMapper) linkerFor(elem *tuple) *tuple { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type tupleList struct { + head *tuple + tail *tuple +} + +// Reset resets list l to the empty state. +func (l *tupleList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *tupleList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *tupleList) Front() *tuple { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *tupleList) Back() *tuple { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *tupleList) Len() (count int) { + for e := l.Front(); e != nil; e = (tupleElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *tupleList) PushFront(e *tuple) { + linker := tupleElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + tupleElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *tupleList) PushBack(e *tuple) { + linker := tupleElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + tupleElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *tupleList) PushBackList(m *tupleList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + tupleElementMapper{}.linkerFor(l.tail).SetNext(m.head) + tupleElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *tupleList) InsertAfter(b, e *tuple) { + bLinker := tupleElementMapper{}.linkerFor(b) + eLinker := tupleElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + tupleElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *tupleList) InsertBefore(a, e *tuple) { + aLinker := tupleElementMapper{}.linkerFor(a) + eLinker := tupleElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + tupleElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *tupleList) Remove(e *tuple) { + linker := tupleElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + tupleElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + tupleElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type tupleEntry struct { + next *tuple + prev *tuple +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *tupleEntry) Next() *tuple { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *tupleEntry) Prev() *tuple { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *tupleEntry) SetNext(elem *tuple) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *tupleEntry) SetPrev(elem *tuple) { + e.prev = elem +} |