diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 15 | ||||
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 644 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 359 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 291 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 881 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache.go | 156 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache_test.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry.go | 398 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_test.go | 575 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighborstate_string.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 258 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/nud_test.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/stack/pending_packets.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 104 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 105 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 4 |
21 files changed, 1619 insertions, 2262 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index ee23c9b98..49362333a 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -4,18 +4,6 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) go_template_instance( - name = "linkaddrentry_list", - out = "linkaddrentry_list.go", - package = "stack", - prefix = "linkAddrEntry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*linkAddrEntry", - "Linker": "*linkAddrEntry", - }, -) - -go_template_instance( name = "neighbor_entry_list", out = "neighbor_entry_list.go", package = "stack", @@ -62,8 +50,6 @@ go_library( "iptables_state.go", "iptables_targets.go", "iptables_types.go", - "linkaddrcache.go", - "linkaddrentry_list.go", "neighbor_cache.go", "neighbor_entry.go", "neighbor_entry_list.go", @@ -141,7 +127,6 @@ go_test( size = "small", srcs = [ "forwarding_test.go", - "linkaddrcache_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", "nic_test.go", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 54617f2e6..cdb435644 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -231,6 +231,12 @@ func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { return &conn } +func (ct *ConnTrack) init() { + ct.mu.Lock() + defer ct.mu.Unlock() + ct.buckets = make([]bucket, numBuckets) +} + // connFor gets the conn for pkt if it exists, or returns nil // if it does not. It returns an error when pkt does not contain a valid TCP // header. diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index c24f56ece..c987c1851 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -75,6 +75,10 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { } func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { + if _, _, ok := f.proto.Parse(pkt); !ok { + return + } + netHdr := pkt.NetworkHeader().View() _, dst := f.proto.ParseAddresses(netHdr) @@ -161,9 +165,9 @@ var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) type fwdTestNetworkProtocol struct { stack *Stack - neighborTable neighborTable + neigh *neighborCache addrResolveDelay time.Duration - onLinkAddressResolved func(neighborTable, tcpip.Address, tcpip.LinkAddress) + onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) mu struct { @@ -221,7 +225,7 @@ func (*fwdTestNetworkProtocol) Wait() {} func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { if fn := f.proto.onLinkAddressResolved; fn != nil { time.AfterFunc(f.proto.addrResolveDelay, func() { - fn(f.proto.neighborTable, addr, remoteLinkAddr) + fn(f.proto.neigh, addr, remoteLinkAddr) }) } return nil @@ -357,14 +361,13 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco panic("not implemented") } -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { proto.stack = s return proto }}, - UseNeighborCache: useNeighborCache, }) // Enable forwarding. @@ -402,7 +405,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC } if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok { - proto.neighborTable = l.neighborTable + proto.neigh = &l.neigh } // Route all packets to NIC 2. @@ -418,121 +421,85 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC } func TestForwardingWithStaticResolver(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + // Create a network protocol with a static resolver. + proto := &fwdTestNetworkProtocol{ + onResolveStaticAddress: + // The network address 3 is resolved to the link address "c". + func(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "\x03" { + return "c", true + } + return "", false }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // Create a network protocol with a static resolver. - proto := &fwdTestNetworkProtocol{ - onResolveStaticAddress: - // The network address 3 is resolved to the link address "c". - func(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "\x03" { - return "c", true - } - return "", false - }, - } - - ep1, ep2 := fwdTestNetFactory(t, proto, test.useNeighborCache) + ep1, ep2 := fwdTestNetFactory(t, proto) - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - var p fwdTestPacketInfo + var p fwdTestPacketInfo - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } + select { + case p = <-ep2.C: + default: + t.Fatal("packet not forwarded") + } - // Test that the static address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - }) + // Test that the static address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) } } func TestForwardingWithFakeResolver(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Any address will be resolved to the link address "c". + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any address will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - }) + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) } } @@ -542,7 +509,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // Whether or not we use the neighbor cache here does not matter since // neither linkAddrCache nor neighborCache will be used. - ep1, ep2 := fwdTestNetFactory(t, proto, false /* useNeighborCache */) + ep1, ep2 := fwdTestNetFactory(t, proto) // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -562,12 +529,12 @@ func TestForwardingWithNoResolver(t *testing.T) { func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { proto := &fwdTestNetworkProtocol{ addrResolveDelay: 50 * time.Millisecond, - onLinkAddressResolved: func(neighborTable, tcpip.Address, tcpip.LinkAddress) { + onLinkAddressResolved: func(*neighborCache, tcpip.Address, tcpip.LinkAddress) { // Don't resolve the link address. }, } - ep1, ep2 := fwdTestNetFactory(t, proto, true /* useNeighborCache */) + ep1, ep2 := fwdTestNetFactory(t, proto) const numPackets int = 5 // These packets will all be enqueued in the packet queue to wait for link @@ -592,300 +559,227 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { } func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + } }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - } - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - // Inject an inbound packet to address 4 on NIC 1. This packet should - // not be forwarded. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf = buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + // Inject an inbound packet to address 4 on NIC 1. This packet should + // not be forwarded. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 4 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf = buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - }) + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) } } func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - // Inject two inbound packets to address 3 on NIC 1. - for i := 0; i < 2; i++ { - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } + // Inject two inbound packets to address 3 on NIC 1. + for i := 0; i < 2; i++ { + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - for i := 0; i < 2; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } - }) + for i := 0; i < 2; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } } } func TestForwardingWithFakeResolverManyPackets(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - for i := 0; i < maxPendingPacketsPerResolution+5; i++ { - // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - // Set the packet sequence number. - binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } + for i := 0; i < maxPendingPacketsPerResolution+5; i++ { + // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + // Set the packet sequence number. + binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - for i := 0; i < maxPendingPacketsPerResolution; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - b := PayloadSince(p.Pkt.NetworkHeader()) - if b[dstAddrOffset] != 3 { - t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) - } - if len(b) < fwdTestNetHeaderLen+2 { - t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) - } - seqNumBuf := b[fwdTestNetHeaderLen:] - - // The first 5 packets should not be forwarded so the sequence number should - // start with 5. - want := uint16(i + 5) - if n := binary.BigEndian.Uint16(seqNumBuf); n != want { - t.Fatalf("got the packet #%d, want = #%d", n, want) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } - }) + for i := 0; i < maxPendingPacketsPerResolution; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := PayloadSince(p.Pkt.NetworkHeader()) + if b[dstAddrOffset] != 3 { + t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) + } + if len(b) < fwdTestNetHeaderLen+2 { + t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) + } + seqNumBuf := b[fwdTestNetHeaderLen:] + + // The first 5 packets should not be forwarded so the sequence number should + // start with 5. + want := uint16(i + 5) + if n := binary.BigEndian.Uint16(seqNumBuf); n != want { + t.Fatalf("got the packet #%d, want = #%d", n, want) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } } } func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - proto *fwdTestNetworkProtocol - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - for i := 0; i < maxPendingResolutions+5; i++ { - // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. - // Each packet has a different destination address (3 to - // maxPendingResolutions + 7). - buf := buffer.NewView(30) - buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } + for i := 0; i < maxPendingResolutions+5; i++ { + // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. + // Each packet has a different destination address (3 to + // maxPendingResolutions + 7). + buf := buffer.NewView(30) + buf[dstAddrOffset] = byte(3 + i) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - for i := 0; i < maxPendingResolutions; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - // The first 5 packets (address 3 to 7) should not be forwarded - // because their address resolutions are interrupted. - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } - }) + for i := 0; i < maxPendingResolutions; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // The first 5 packets (address 3 to 7) should not be forwarded + // because their address resolutions are interrupted. + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } } } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 63832c200..52890f6eb 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -235,7 +235,7 @@ func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) tcpip.Error // If iptables is being enabled, initialize the conntrack table and // reaper. if !it.modified { - it.connections.buckets = make([]bucket, numBuckets) + it.connections.init() it.startReaper(reaperDelay) } it.modified = true diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go deleted file mode 100644 index 5b6b58b1d..000000000 --- a/pkg/tcpip/stack/linkaddrcache.go +++ /dev/null @@ -1,359 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" -) - -const linkAddrCacheSize = 512 // max cache entries - -// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses. -// -// The entries are stored in a ring buffer, oldest entry replaced first. -// -// This struct is safe for concurrent use. -type linkAddrCache struct { - nic *NIC - - linkRes LinkAddressResolver - - // ageLimit is how long a cache entry is valid for. - ageLimit time.Duration - - // resolutionTimeout is the amount of time to wait for a link request to - // resolve an address. - resolutionTimeout time.Duration - - // resolutionAttempts is the number of times an address is attempted to be - // resolved before failing. - resolutionAttempts int - - mu struct { - sync.Mutex - table map[tcpip.Address]*linkAddrEntry - lru linkAddrEntryList - } -} - -// entryState controls the state of a single entry in the cache. -type entryState int - -const ( - // incomplete means that there is an outstanding request to resolve the - // address. This is the initial state. - incomplete entryState = iota - // ready means that the address has been resolved and can be used. - ready -) - -// String implements Stringer. -func (s entryState) String() string { - switch s { - case incomplete: - return "incomplete" - case ready: - return "ready" - default: - return fmt.Sprintf("unknown(%d)", s) - } -} - -// A linkAddrEntry is an entry in the linkAddrCache. -// This struct is thread-compatible. -type linkAddrEntry struct { - // linkAddrEntryEntry access is synchronized by the linkAddrCache lock. - linkAddrEntryEntry - - cache *linkAddrCache - - mu struct { - sync.RWMutex - - addr tcpip.Address - linkAddr tcpip.LinkAddress - expiration time.Time - s entryState - - // done is closed when address resolution is complete. It is nil iff s is - // incomplete and resolution is not yet in progress. - done chan struct{} - - // onResolve is called with the result of address resolution. - onResolve []func(LinkResolutionResult) - } -} - -func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { - res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0} - for _, callback := range e.mu.onResolve { - callback(res) - } - e.mu.onResolve = nil - if ch := e.mu.done; ch != nil { - close(ch) - e.mu.done = nil - // Dequeue the pending packets in a new goroutine to not hold up the current - // goroutine as writing packets may be a costly operation. - // - // At the time of writing, when writing packets, a neighbor's link address - // is resolved (which ends up obtaining the entry's lock) while holding the - // link resolution queue's lock. Dequeuing packets in a new goroutine avoids - // a lock ordering violation. - go e.cache.nic.linkResQueue.dequeue(ch, linkAddr, len(linkAddr) != 0) - } -} - -// changeStateLocked sets the entry's state to ns. -// -// The entry's expiration is bumped up to the greater of itself and the passed -// expiration; the zero value indicates immediate expiration, and is set -// unconditionally - this is an implementation detail that allows for entries -// to be reused. -// -// Precondition: e.mu must be locked -func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { - if e.mu.s == incomplete && ns == ready { - e.notifyCompletionLocked(e.mu.linkAddr) - } - - if expiration.IsZero() || expiration.After(e.mu.expiration) { - e.mu.expiration = expiration - } - e.mu.s = ns -} - -// add adds a k -> v mapping to the cache. -func (c *linkAddrCache) add(k tcpip.Address, v tcpip.LinkAddress) { - // Calculate expiration time before acquiring the lock, since expiration is - // relative to the time when information was learned, rather than when it - // happened to be inserted into the cache. - expiration := time.Now().Add(c.ageLimit) - - c.mu.Lock() - entry := c.getOrCreateEntryLocked(k) - entry.mu.Lock() - defer entry.mu.Unlock() - c.mu.Unlock() - - entry.mu.linkAddr = v - entry.changeStateLocked(ready, expiration) -} - -// getOrCreateEntryLocked retrieves a cache entry associated with k. The -// returned entry is always refreshed in the cache (it is reachable via the -// map, and its place is bumped in LRU). -// -// If a matching entry exists in the cache, it is returned. If no matching -// entry exists and the cache is full, an existing entry is evicted via LRU, -// reset to state incomplete, and returned. If no matching entry exists and the -// cache is not full, a new entry with state incomplete is allocated and -// returned. -func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { - if entry, ok := c.mu.table[k]; ok { - c.mu.lru.Remove(entry) - c.mu.lru.PushFront(entry) - return entry - } - var entry *linkAddrEntry - if len(c.mu.table) == linkAddrCacheSize { - entry = c.mu.lru.Back() - entry.mu.Lock() - - delete(c.mu.table, entry.mu.addr) - c.mu.lru.Remove(entry) - - // Wake waiters and mark the soon-to-be-reused entry as expired. - entry.notifyCompletionLocked("" /* linkAddr */) - entry.mu.Unlock() - } else { - entry = new(linkAddrEntry) - } - - *entry = linkAddrEntry{ - cache: c, - } - entry.mu.Lock() - entry.mu.addr = k - entry.mu.s = incomplete - entry.mu.Unlock() - c.mu.table[k] = entry - c.mu.lru.PushFront(entry) - return entry -} - -// get reports any known link address for addr. -func (c *linkAddrCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - c.mu.Lock() - entry := c.getOrCreateEntryLocked(addr) - entry.mu.Lock() - defer entry.mu.Unlock() - c.mu.Unlock() - - switch s := entry.mu.s; s { - case ready: - if !time.Now().After(entry.mu.expiration) { - // Not expired. - if onResolve != nil { - onResolve(LinkResolutionResult{LinkAddress: entry.mu.linkAddr, Success: true}) - } - return entry.mu.linkAddr, nil, nil - } - - entry.changeStateLocked(incomplete, time.Time{}) - fallthrough - case incomplete: - if onResolve != nil { - entry.mu.onResolve = append(entry.mu.onResolve, onResolve) - } - if entry.mu.done == nil { - entry.mu.done = make(chan struct{}) - go c.startAddressResolution(addr, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. - } - return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{} - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) - } -} - -func (c *linkAddrCache) startAddressResolution(k tcpip.Address, localAddr tcpip.Address, done <-chan struct{}) { - for i := 0; ; i++ { - // Send link request, then wait for the timeout limit and check - // whether the request succeeded. - c.linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */) - - select { - case now := <-time.After(c.resolutionTimeout): - if stop := c.checkLinkRequest(now, k, i); stop { - return - } - case <-done: - return - } - } -} - -// checkLinkRequest checks whether previous attempt to resolve address has -// succeeded and mark the entry accordingly. Returns true if request can stop, -// false if another request should be sent. -func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt int) bool { - c.mu.Lock() - defer c.mu.Unlock() - entry, ok := c.mu.table[k] - if !ok { - // Entry was evicted from the cache. - return true - } - entry.mu.Lock() - defer entry.mu.Unlock() - - switch s := entry.mu.s; s { - case ready: - // Entry was made ready by resolver. - case incomplete: - if attempt+1 < c.resolutionAttempts { - // No response yet, need to send another ARP request. - return false - } - // Max number of retries reached, delete entry. - entry.notifyCompletionLocked("" /* linkAddr */) - delete(c.mu.table, k) - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) - } - return true -} - -func (c *linkAddrCache) init(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int, linkRes LinkAddressResolver) { - *c = linkAddrCache{ - nic: nic, - linkRes: linkRes, - ageLimit: ageLimit, - resolutionTimeout: resolutionTimeout, - resolutionAttempts: resolutionAttempts, - } - - c.mu.Lock() - c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) - c.mu.Unlock() -} - -var _ neighborTable = (*linkAddrCache)(nil) - -func (*linkAddrCache) neighbors() ([]NeighborEntry, tcpip.Error) { - return nil, &tcpip.ErrNotSupported{} -} - -func (c *linkAddrCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) { - c.add(addr, linkAddr) -} - -func (*linkAddrCache) remove(addr tcpip.Address) tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (*linkAddrCache) removeAll() tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (c *linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress) { - if len(linkAddr) != 0 { - // NUD allows probes without a link address but linkAddrCache - // is a simple neighbor table which does not implement NUD. - // - // As per RFC 4861 section 4.3, - // - // Source link-layer address - // The link-layer address for the sender. MUST NOT be - // included when the source IP address is the - // unspecified address. Otherwise, on link layers - // that have addresses this option MUST be included in - // multicast solicitations and SHOULD be included in - // unicast solicitations. - c.add(addr, linkAddr) - } -} - -func (c *linkAddrCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { - if len(linkAddr) != 0 { - // NUD allows confirmations without a link address but linkAddrCache - // is a simple neighbor table which does not implement NUD. - // - // As per RFC 4861 section 4.4, - // - // Target link-layer address - // The link-layer address for the target, i.e., the - // sender of the advertisement. This option MUST be - // included on link layers that have addresses when - // responding to multicast solicitations. When - // responding to a unicast Neighbor Solicitation this - // option SHOULD be included. - c.add(addr, linkAddr) - } -} - -func (c *linkAddrCache) handleUpperLevelConfirmation(tcpip.Address) {} - -func (*linkAddrCache) nudConfig() (NUDConfigurations, tcpip.Error) { - return NUDConfigurations{}, &tcpip.ErrNotSupported{} -} - -func (*linkAddrCache) setNUDConfig(NUDConfigurations) tcpip.Error { - return &tcpip.ErrNotSupported{} -} diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go deleted file mode 100644 index 9e7f331c9..000000000 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ /dev/null @@ -1,291 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "math" - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" -) - -type testaddr struct { - addr tcpip.Address - linkAddr tcpip.LinkAddress -} - -var testAddrs = func() []testaddr { - var addrs []testaddr - for i := 0; i < 4*linkAddrCacheSize; i++ { - addr := fmt.Sprintf("Addr%06d", i) - addrs = append(addrs, testaddr{ - addr: tcpip.Address(addr), - linkAddr: tcpip.LinkAddress("Link" + addr), - }) - } - return addrs -}() - -type testLinkAddressResolver struct { - cache *linkAddrCache - delay time.Duration - onLinkAddressRequest func() -} - -func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { - // TODO(gvisor.dev/issue/5141): Use a fake clock. - time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) - if f := r.onLinkAddressRequest; f != nil { - f() - } - return nil -} - -func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { - for _, ta := range testAddrs { - if ta.addr == addr { - r.cache.add(ta.addr, ta.linkAddr) - break - } - } -} - -func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "broadcast" { - return "mac_broadcast", true - } - return "", false -} - -func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return 1 -} - -func getBlocking(c *linkAddrCache, addr tcpip.Address) (tcpip.LinkAddress, tcpip.Error) { - var attemptedResolution bool - for { - got, ch, err := c.get(addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - if attemptedResolution { - return got, &tcpip.ErrTimeout{} - } - attemptedResolution = true - <-ch - continue - } - return got, err - } -} - -func newEmptyNIC() *NIC { - n := &NIC{} - n.linkResQueue.init(n) - return n -} - -func TestCacheOverflow(t *testing.T) { - var c linkAddrCache - c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) - for i := len(testAddrs) - 1; i >= 0; i-- { - e := testAddrs[i] - c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("insert %d, c.get(%s, '', nil): %s", i, e.addr, err) - } - if got != e.linkAddr { - t.Errorf("insert %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) - } - } - // Expect to find at least half of the most recent entries. - for i := 0; i < linkAddrCacheSize/2; i++ { - e := testAddrs[i] - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("check %d, c.get(%s, '', nil): %s", i, e.addr, err) - } - if got != e.linkAddr { - t.Errorf("check %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) - } - } - // The earliest entries should no longer be in the cache. - c.mu.Lock() - defer c.mu.Unlock() - for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { - e := testAddrs[i] - if entry, ok := c.mu.table[e.addr]; ok { - t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) - } - } -} - -func TestCacheConcurrent(t *testing.T) { - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c} - c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, linkRes) - - var wg sync.WaitGroup - for r := 0; r < 16; r++ { - wg.Add(1) - go func() { - for _, e := range testAddrs { - c.add(e.addr, e.linkAddr) - } - wg.Done() - }() - } - wg.Wait() - - // All goroutines add in the same order and add more values than - // can fit in the cache, so our eviction strategy requires that - // the last entry be present and the first be missing. - e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("c.get(%s, '', nil): %s", e.addr, err) - } - if got != e.linkAddr { - t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) - } - - e = testAddrs[0] - c.mu.Lock() - defer c.mu.Unlock() - if entry, ok := c.mu.table[e.addr]; ok { - t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) - } -} - -func TestCacheAgeLimit(t *testing.T) { - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c} - c.init(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3, linkRes) - - e := testAddrs[0] - c.add(e.addr, e.linkAddr) - time.Sleep(50 * time.Millisecond) - _, _, err := c.get(e.addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.get(%s, '', nil) = %s, want = ErrWouldBlock", e.addr, err) - } -} - -func TestCacheReplace(t *testing.T) { - var c linkAddrCache - c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) - e := testAddrs[0] - l2 := e.linkAddr + "2" - c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("c.get(%s, '', nil): %s", e.addr, err) - } - if got != e.linkAddr { - t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) - } - - c.add(e.addr, l2) - got, _, err = c.get(e.addr, "", nil) - if err != nil { - t.Errorf("c.get(%s, '', nil): %s", e.addr, err) - } - if got != l2 { - t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, l2) - } -} - -func TestCacheResolution(t *testing.T) { - // There is a race condition causing this test to fail when the executor - // takes longer than the resolution timeout to call linkAddrCache.get. This - // is especially common when this test is run with gotsan. - // - // Using a large resolution timeout decreases the probability of experiencing - // this race condition and does not affect how long this test takes to run. - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c} - c.init(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1, linkRes) - for i, ta := range testAddrs { - got, err := getBlocking(&c, ta.addr) - if err != nil { - t.Errorf("check %d, getBlocking(_, %s): %s", i, ta.addr, err) - } - if got != ta.linkAddr { - t.Errorf("check %d, got getBlocking(_, %s) = %s, want = %s", i, ta.addr, got, ta.linkAddr) - } - } - - // Check that after resolved, address stays in the cache and never returns WouldBlock. - for i := 0; i < 10; i++ { - e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("c.get(%s, '', nil): %s", e.addr, err) - } - if got != e.linkAddr { - t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) - } - } -} - -func TestCacheResolutionFailed(t *testing.T) { - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c} - c.init(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5, linkRes) - - var requestCount uint32 - linkRes.onLinkAddressRequest = func() { - atomic.AddUint32(&requestCount, 1) - } - - // First, sanity check that resolution is working... - e := testAddrs[0] - got, err := getBlocking(&c, e.addr) - if err != nil { - t.Errorf("getBlocking(_, %s): %s", e.addr, err) - } - if got != e.linkAddr { - t.Errorf("got getBlocking(_, %s) = %s, want = %s", e.addr, got, e.linkAddr) - } - - before := atomic.LoadUint32(&requestCount) - - e.addr += "2" - a, err := getBlocking(&c, e.addr) - if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) - } - - if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { - t.Errorf("got link address request count = %d, want = %d", got, want) - } -} - -func TestCacheResolutionTimeout(t *testing.T) { - resolverDelay := 500 * time.Millisecond - expiration := resolverDelay / 10 - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c, delay: resolverDelay} - c.init(newEmptyNIC(), expiration, 1*time.Millisecond, 3, linkRes) - - e := testAddrs[0] - a, err := getBlocking(&c, e.addr) - if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) - } -} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 0238605af..3b6ba9509 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -424,7 +424,7 @@ func TestDADResolve(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, - NDPConfigs: ipv6.NDPConfigurations{ + DADConfigs: stack.DADConfigurations{ RetransmitTimer: test.retransTimer, DupAddrDetectTransmits: test.dupAddrDetectTransmits, }, @@ -642,14 +642,14 @@ func TestDADFail(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - ndpConfigs := ipv6.DefaultNDPConfigurations() - ndpConfigs.RetransmitTimer = time.Second * 2 + dadConfigs := stack.DefaultDADConfigurations() + dadConfigs.RetransmitTimer = time.Second * 2 e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, + DADConfigs: dadConfigs, })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -677,7 +677,7 @@ func TestDADFail(t *testing.T) { // Wait for DAD to fail and make sure the address did // not get resolved. select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): // If we don't get a failure event after the // expected resolution time + extra 1s buffer, // something is wrong. @@ -748,7 +748,7 @@ func TestDADStop(t *testing.T) { dadC: make(chan ndpDADEvent, 1), } - ndpConfigs := ipv6.NDPConfigurations{ + dadConfigs := stack.DADConfigurations{ RetransmitTimer: time.Second, DupAddrDetectTransmits: 2, } @@ -757,7 +757,7 @@ func TestDADStop(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, + DADConfigs: dadConfigs, })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -777,12 +777,12 @@ func TestDADStop(t *testing.T) { // Wait for DAD to fail (since the address was removed during DAD). select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): // If we don't get a failure event after the expected resolution // time + extra 1s buffer, something is wrong. t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { + if diff := checkDADEvent(e, nicID, addr1, false, &tcpip.ErrAborted{}); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } @@ -865,16 +865,15 @@ func TestSetNDPConfigurations(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) } - // Update the NDP configurations on NIC(1) to use DAD. - configs := ipv6.NDPConfigurations{ - DupAddrDetectTransmits: test.dupAddrDetectTransmits, - RetransmitTimer: test.retransmitTimer, - } + // Update the configurations on NIC(1) to use DAD. if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil { t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err) } else { - ndpEP := ipv6Ep.(ipv6.NDPEndpoint) - ndpEP.SetNDPConfigurations(configs) + dad := ipv6Ep.(stack.DuplicateAddressDetector) + dad.SetDADConfigurations(stack.DADConfigurations{ + DupAddrDetectTransmits: test.dupAddrDetectTransmits, + RetransmitTimer: test.retransmitTimer, + }) } // Created after updating NIC(1)'s NDP configurations @@ -1903,9 +1902,11 @@ func TestAutoGenTempAddr(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + }, NDPConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, @@ -2202,9 +2203,11 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, NDPConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, @@ -2635,16 +2638,15 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: test.tempAddrs, - AutoGenAddressConflictRetries: 1, - } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: test.tempAddrs, + AutoGenAddressConflictRetries: 1, + }, + NDPDisp: &ndpDisp, OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: test.nicNameFromID, }, @@ -2718,13 +2720,14 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // Enable DAD. ndpDisp.dadC = make(chan ndpDADEvent, 2) - ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits - ndpConfigs.RetransmitTimer = retransmitTimer if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) } else { - ndpEP := ipv6Ep.(ipv6.NDPEndpoint) - ndpEP.SetNDPConfigurations(ndpConfigs) + ndpEP := ipv6Ep.(stack.DuplicateAddressDetector) + ndpEP.SetDADConfigurations(stack.DADConfigurations{ + DupAddrDetectTransmits: dupAddrTransmits, + RetransmitTimer: retransmitTimer, + }) } // Do SLAAC for prefix. @@ -2769,7 +2772,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // stack.Stack will have a default route through the router (llAddr3) installed // and a static link-address (linkAddr3) added to the link address cache for the // router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useNeighborCache bool) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { t.Helper() ndpDisp := &ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), @@ -2785,7 +2788,6 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN NDPDisp: ndpDisp, })}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2868,126 +2870,108 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA // TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when // receiving a PI with 0 preferred lifetime. func TestAutoGenAddrDeprecateFromPI(t *testing.T) { - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - const nicID = 1 + const nicID = 1 - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } + } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) + } - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Receive PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - expectPrimaryAddr(addr1) + // Receive PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + expectPrimaryAddr(addr1) - // Deprecate addr for prefix1 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - // addr should still be the primary endpoint as there are no other addresses. - expectPrimaryAddr(addr1) + // Deprecate addr for prefix1 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + // addr should still be the primary endpoint as there are no other addresses. + expectPrimaryAddr(addr1) - // Refresh lifetimes of addr generated from prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) + // Refresh lifetimes of addr generated from prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Deprecate addr for prefix2 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr1 should be the primary endpoint now since addr2 is deprecated but - // addr1 is not. - expectPrimaryAddr(addr1) - // addr2 is deprecated but if explicitly requested, it should be used. - fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } + // Deprecate addr for prefix2 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr1 should be the primary endpoint now since addr2 is deprecated but + // addr1 is not. + expectPrimaryAddr(addr1) + // addr2 is deprecated but if explicitly requested, it should be used. + fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } - // Another PI w/ 0 preferred lifetime should not result in a deprecation - // event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } + // Another PI w/ 0 preferred lifetime should not result in a deprecation + // event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } - // Refresh lifetimes of addr generated from prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) - }) + // Refresh lifetimes of addr generated from prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: } + expectPrimaryAddr(addr2) } // TestAutoGenAddrJobDeprecation tests that an address is properly deprecated @@ -2997,232 +2981,214 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { const newMinVL = 2 newMinVLDuration := newMinVL * time.Second - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate + defer func() { + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved + }() + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } + } - expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() + expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(timeout): - t.Fatal("timed out waiting for addr auto gen event") - } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) + } - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Receive a PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr1) + // Receive a PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr1) - // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) + // Refresh lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since addr1 is deprecated but - // addr2 is not. - expectPrimaryAddr(addr2) - // addr1 is deprecated but if explicitly requested, it should be used. - fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since addr1 is deprecated but + // addr2 is not. + expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. + fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make - // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make + // sure we do not get a deprecation event again. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - // addr1 is the primary endpoint again since it is non-deprecated now. - expectPrimaryAddr(addr1) + // Refresh lifetimes for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + // addr1 is the primary endpoint again since it is non-deprecated now. + expectPrimaryAddr(addr1) - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since it is not deprecated. - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since it is not deprecated. + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Wait for addr of prefix1 to be invalidated. + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } + // Refresh both lifetimes for addr of prefix2 to the same value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } - // Wait for a deprecation then invalidation events, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation handlers could be handled in - // either deprecation then invalidation, or invalidation then deprecation - // (which should be cancelled by the invalidation handler). + // Wait for a deprecation then invalidation events, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation handlers could be handled in + // either deprecation then invalidation, or invalidation then deprecation + // (which should be cancelled by the invalidation handler). + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. select { case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation - // event after. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultAsyncNegativeEventTimeout): - } - } else { - t.Fatalf("got unexpected auto-generated event") + if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should not have %s in the list of addresses", addr2) - } - // Should not have any primary endpoints. - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we should not get a deprecation + // event after. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + case <-time.After(defaultAsyncNegativeEventTimeout): } - defer ep.Close() - ep.SocketOptions().SetV6Only(true) + } else { + t.Fatalf("got unexpected auto-generated event") + } + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should not have %s in the list of addresses", addr2) + } + // Should not have any primary endpoints. + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) + } + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + ep.SocketOptions().SetV6Only(true) - { - err := ep.Connect(dstAddr) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{}) - } - } - }) + { + err := ep.Connect(dstAddr) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{}) + } } } @@ -3543,126 +3509,108 @@ func TestAutoGenAddrRemoval(t *testing.T) { func TestAutoGenAddrAfterRemoval(t *testing.T) { const nicID = 1 - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } } - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + // Receive a PI to auto-generate addr1 with a large valid and preferred + // lifetime. + const largeLifetimeSeconds = 999 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr1, newAddr) + expectPrimaryAddr(addr1) - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } + // Add addr2 as a static address. + protoAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr2, + } + if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + } + // addr2 should be more preferred now since it is at the front of the primary + // list. + expectPrimaryAddr(addr2) - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + // Get a route using addr2 to increment its reference count then remove it + // to leave it in the permanentExpired state. + r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) + } + defer r.Release() + if err := s.RemoveAddress(nicID, addr2.Address); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) + } + // addr1 should be preferred again since addr2 is in the expired state. + expectPrimaryAddr(addr1) - // Receive a PI to auto-generate addr1 with a large valid and preferred - // lifetime. - const largeLifetimeSeconds = 999 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr1, newAddr) - expectPrimaryAddr(addr1) - - // Add addr2 as a static address. - protoAddr2 := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr2, - } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) - } - // addr2 should be more preferred now since it is at the front of the primary - // list. - expectPrimaryAddr(addr2) - - // Get a route using addr2 to increment its reference count then remove it - // to leave it in the permanentExpired state. - r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) - } - defer r.Release() - if err := s.RemoveAddress(nicID, addr2.Address); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) - } - // addr1 should be preferred again since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and preferred. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr2 should be more preferred now that it is closer to the front of the - // primary list and not deprecated. - expectPrimaryAddr(addr2) - - // Removing the address should result in an invalidation event immediately. - // It should still be in the permanentExpired state because r is still held. - // - // We remove addr2 here to make sure addr2 was marked as a SLAAC address - // (it was previously marked as a static address). - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - // addr1 should be more preferred since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr1 should still be more preferred since addr2 is deprecated, even though - // it is closer to the front of the primary list. - expectPrimaryAddr(addr1) - - // Receive a PI to refresh addr2's preferred lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto gen addr event") - default: - } - // addr2 should be more preferred now that it is not deprecated. - expectPrimaryAddr(addr2) + // Receive a PI to auto-generate addr2 as valid and preferred. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr2 should be more preferred now that it is closer to the front of the + // primary list and not deprecated. + expectPrimaryAddr(addr2) - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - expectPrimaryAddr(addr1) - }) + // Removing the address should result in an invalidation event immediately. + // It should still be in the permanentExpired state because r is still held. + // + // We remove addr2 here to make sure addr2 was marked as a SLAAC address + // (it was previously marked as a static address). + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + // addr1 should be more preferred since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr1 should still be more preferred since addr2 is deprecated, even though + // it is closer to the front of the primary list. + expectPrimaryAddr(addr1) + + // Receive a PI to refresh addr2's preferred lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto gen addr event") + default: + } + // addr2 should be more preferred now that it is not deprecated. + expectPrimaryAddr(addr2) + + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + expectPrimaryAddr(addr1) } // TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that @@ -3885,12 +3833,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { + expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool, err tcpip.Error) { t.Helper() select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { + if diff := checkDADEvent(e, nicID, addr, resolved, err); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } default: @@ -3923,8 +3871,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, HandleRAs: true, AutoGenGlobalAddresses: true, }, @@ -3939,11 +3885,8 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { }, }, { - name: "LinkLocal address", - ndpConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, + name: "LinkLocal address", + ndpConfigs: ipv6.NDPConfigurations{}, autoGenLinkLocal: true, prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { return nil @@ -3955,8 +3898,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Temporary address", ndpConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, @@ -4008,8 +3949,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(_ tcpip.NICID, nicName string) string { return nicName @@ -4039,7 +3984,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Simulate a DAD conflict. rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) - expectDADEvent(t, &ndpDisp, addr.Address, false) + expectDADEvent(t, &ndpDisp, addr.Address, false, nil) // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. @@ -4049,7 +3994,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) } - expectDADEvent(t, &ndpDisp, addr.Address, false) + expectDADEvent(t, &ndpDisp, addr.Address, false, &tcpip.ErrAborted{}) } // Should not have any new addresses assigned to the NIC. @@ -4103,8 +4048,6 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, @@ -4119,8 +4062,6 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { { name: "LinkLocal address", ndpConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, AutoGenAddressConflictRetries: maxRetries, }, autoGenLinkLocal: true, @@ -4145,6 +4086,10 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { AutoGenLinkLocal: addrType.autoGenLinkLocal, NDPConfigs: addrType.ndpConfigs, NDPDisp: &ndpDisp, + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -4226,9 +4171,11 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, NDPConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 7e3132058..533287c4c 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -25,9 +25,13 @@ const neighborCacheSize = 512 // max entries per interface // NeighborStats holds metrics for the neighbor table. type NeighborStats struct { - // FailedEntryLookups counts the number of lookups performed on an entry in - // Failed state. + // FailedEntryLookups is deprecated; UnreachableEntryLookups should be used + // instead. FailedEntryLookups *tcpip.StatCounter + + // UnreachableEntryLookups counts the number of lookups performed on an + // entry in Unreachable state. + UnreachableEntryLookups *tcpip.StatCounter } // neighborCache maps IP addresses to link addresses. It uses the Least @@ -43,21 +47,22 @@ type NeighborStats struct { // Their state is always Static. The amount of static entries stored in the // cache is unbounded. type neighborCache struct { - nic *NIC + nic *nic state *NUDState linkRes LinkAddressResolver - // mu protects the fields below. - mu sync.RWMutex + mu struct { + sync.RWMutex - cache map[tcpip.Address]*neighborEntry - dynamic struct { - lru neighborEntryList + cache map[tcpip.Address]*neighborEntry + dynamic struct { + lru neighborEntryList - // count tracks the amount of dynamic entries in the cache. This is - // needed since static entries do not count towards the LRU cache - // eviction strategy. - count uint16 + // count tracks the amount of dynamic entries in the cache. This is + // needed since static entries do not count towards the LRU cache + // eviction strategy. + count uint16 + } } } @@ -74,11 +79,11 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntr n.mu.Lock() defer n.mu.Unlock() - if entry, ok := n.cache[remoteAddr]; ok { + if entry, ok := n.mu.cache[remoteAddr]; ok { entry.mu.RLock() - if entry.neigh.State != Static { - n.dynamic.lru.Remove(entry) - n.dynamic.lru.PushFront(entry) + if entry.mu.neigh.State != Static { + n.mu.dynamic.lru.Remove(entry) + n.mu.dynamic.lru.PushFront(entry) } entry.mu.RUnlock() return entry @@ -87,20 +92,20 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntr // The entry that needs to be created must be dynamic since all static // entries are directly added to the cache via addStaticEntry. entry := newNeighborEntry(n, remoteAddr, n.state) - if n.dynamic.count == neighborCacheSize { - e := n.dynamic.lru.Back() + if n.mu.dynamic.count == neighborCacheSize { + e := n.mu.dynamic.lru.Back() e.mu.Lock() - delete(n.cache, e.neigh.Addr) - n.dynamic.lru.Remove(e) - n.dynamic.count-- + delete(n.mu.cache, e.mu.neigh.Addr) + n.mu.dynamic.lru.Remove(e) + n.mu.dynamic.count-- e.removeLocked() e.mu.Unlock() } - n.cache[remoteAddr] = entry - n.dynamic.lru.PushFront(entry) - n.dynamic.count++ + n.mu.cache[remoteAddr] = entry + n.mu.dynamic.lru.PushFront(entry) + n.mu.dynamic.count++ return entry } @@ -128,7 +133,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve fun entry.mu.Lock() defer entry.mu.Unlock() - switch s := entry.neigh.State; s { + switch s := entry.mu.neigh.State; s { case Stale: entry.handlePacketQueuedLocked(localAddr) fallthrough @@ -139,19 +144,19 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve fun // a node continues sending packets to that neighbor using the cached // link-layer address." if onResolve != nil { - onResolve(LinkResolutionResult{LinkAddress: entry.neigh.LinkAddr, Success: true}) + onResolve(LinkResolutionResult{LinkAddress: entry.mu.neigh.LinkAddr, Success: true}) } - return entry.neigh, nil, nil - case Unknown, Incomplete, Failed: + return entry.mu.neigh, nil, nil + case Unknown, Incomplete, Unreachable: if onResolve != nil { - entry.onResolve = append(entry.onResolve, onResolve) + entry.mu.onResolve = append(entry.mu.onResolve, onResolve) } - if entry.done == nil { + if entry.mu.done == nil { // Address resolution needs to be initiated. - entry.done = make(chan struct{}) + entry.mu.done = make(chan struct{}) } entry.handlePacketQueuedLocked(localAddr) - return entry.neigh, entry.done, &tcpip.ErrWouldBlock{} + return entry.mu.neigh, entry.mu.done, &tcpip.ErrWouldBlock{} default: panic(fmt.Sprintf("Invalid cache entry state: %s", s)) } @@ -162,10 +167,10 @@ func (n *neighborCache) entries() []NeighborEntry { n.mu.RLock() defer n.mu.RUnlock() - entries := make([]NeighborEntry, 0, len(n.cache)) - for _, entry := range n.cache { + entries := make([]NeighborEntry, 0, len(n.mu.cache)) + for _, entry := range n.mu.cache { entry.mu.RLock() - entries = append(entries, entry.neigh) + entries = append(entries, entry.mu.neigh) entry.mu.RUnlock() } return entries @@ -181,19 +186,19 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd n.mu.Lock() defer n.mu.Unlock() - if entry, ok := n.cache[addr]; ok { + if entry, ok := n.mu.cache[addr]; ok { entry.mu.Lock() - if entry.neigh.State != Static { + if entry.mu.neigh.State != Static { // Dynamic entry found with the same address. - n.dynamic.lru.Remove(entry) - n.dynamic.count-- - } else if entry.neigh.LinkAddr == linkAddr { + n.mu.dynamic.lru.Remove(entry) + n.mu.dynamic.count-- + } else if entry.mu.neigh.LinkAddr == linkAddr { // Static entry found with the same address and link address. entry.mu.Unlock() return } else { // Static entry found with the same address but different link address. - entry.neigh.LinkAddr = linkAddr + entry.mu.neigh.LinkAddr = linkAddr entry.dispatchChangeEventLocked() entry.mu.Unlock() return @@ -203,7 +208,12 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd entry.mu.Unlock() } - n.cache[addr] = newStaticNeighborEntry(n, addr, linkAddr, n.state) + entry := newStaticNeighborEntry(n, addr, linkAddr, n.state) + n.mu.cache[addr] = entry + + entry.mu.Lock() + defer entry.mu.Unlock() + entry.dispatchAddEventLocked() } // removeEntry removes a dynamic or static entry by address from the neighbor @@ -212,7 +222,7 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool { n.mu.Lock() defer n.mu.Unlock() - entry, ok := n.cache[addr] + entry, ok := n.mu.cache[addr] if !ok { return false } @@ -220,13 +230,13 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool { entry.mu.Lock() defer entry.mu.Unlock() - if entry.neigh.State != Static { - n.dynamic.lru.Remove(entry) - n.dynamic.count-- + if entry.mu.neigh.State != Static { + n.mu.dynamic.lru.Remove(entry) + n.mu.dynamic.count-- } entry.removeLocked() - delete(n.cache, entry.neigh.Addr) + delete(n.mu.cache, entry.mu.neigh.Addr) return true } @@ -235,15 +245,15 @@ func (n *neighborCache) clear() { n.mu.Lock() defer n.mu.Unlock() - for _, entry := range n.cache { + for _, entry := range n.mu.cache { entry.mu.Lock() entry.removeLocked() entry.mu.Unlock() } - n.dynamic.lru = neighborEntryList{} - n.cache = make(map[tcpip.Address]*neighborEntry) - n.dynamic.count = 0 + n.mu.dynamic.lru = neighborEntryList{} + n.mu.cache = make(map[tcpip.Address]*neighborEntry) + n.mu.dynamic.count = 0 } // config returns the NUD configuration. @@ -260,30 +270,6 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { n.state.SetConfig(config) } -var _ neighborTable = (*neighborCache)(nil) - -func (n *neighborCache) neighbors() ([]NeighborEntry, tcpip.Error) { - return n.entries(), nil -} - -func (n *neighborCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - entry, ch, err := n.entry(addr, localAddr, onResolve) - return entry.LinkAddr, ch, err -} - -func (n *neighborCache) remove(addr tcpip.Address) tcpip.Error { - if !n.removeEntry(addr) { - return &tcpip.ErrBadAddress{} - } - - return nil -} - -func (n *neighborCache) removeAll() tcpip.Error { - n.clear() - return nil -} - // handleProbe handles a neighbor probe as defined by RFC 4861 section 7.2.3. // // Validation of the probe is expected to be handled by the caller. @@ -300,7 +286,7 @@ func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcp // Validation of the confirmation is expected to be handled by the caller. func (n *neighborCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { n.mu.RLock() - entry, ok := n.cache[addr] + entry, ok := n.mu.cache[addr] n.mu.RUnlock() if ok { entry.mu.Lock() @@ -316,7 +302,7 @@ func (n *neighborCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.Li // some protocol that operates at a layer above the IP/link layer. func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { n.mu.RLock() - entry, ok := n.cache[addr] + entry, ok := n.mu.cache[addr] n.mu.RUnlock() if ok { entry.mu.Lock() @@ -325,11 +311,13 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { } } -func (n *neighborCache) nudConfig() (NUDConfigurations, tcpip.Error) { - return n.config(), nil -} - -func (n *neighborCache) setNUDConfig(c NUDConfigurations) tcpip.Error { - n.setConfig(c) - return nil +func (n *neighborCache) init(nic *nic, r LinkAddressResolver) { + *n = neighborCache{ + nic: nic, + state: NewNUDState(nic.stack.nudConfigs, nic.stack.randomGenerator), + linkRes: r, + } + n.mu.Lock() + n.mu.cache = make(map[tcpip.Address]*neighborEntry, neighborCacheSize) + n.mu.Unlock() } diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index b489b5e08..909912662 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -84,19 +84,16 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl entries: newTestEntryStore(), delay: typicalLatency, } - linkRes.neigh = &neighborCache{ - nic: &NIC{ - stack: &Stack{ - clock: clock, - nudDisp: nudDisp, - }, - id: 1, - stats: makeNICStats(), + linkRes.neigh.init(&nic{ + stack: &Stack{ + clock: clock, + nudDisp: nudDisp, + nudConfigs: config, + randomGenerator: rng, }, - state: NewNUDState(config, rng), - linkRes: linkRes, - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), - } + id: 1, + stats: makeNICStats(), + }, linkRes) return linkRes } @@ -190,7 +187,7 @@ func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { // neighbor probe. type testNeighborResolver struct { clock tcpip.Clock - neigh *neighborCache + neigh neighborCache entries *testEntryStore delay time.Duration onLinkAddressRequest func() @@ -1613,12 +1610,11 @@ func TestNeighborCacheRetryResolution(t *testing.T) { } } - // Verify the entry is in Failed state. wantEntries := []NeighborEntry{ { Addr: entry.Addr, LinkAddr: "", - State: Failed, + State: Unreachable, }, } if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index b05f96d4f..03fef52ee 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -38,7 +38,8 @@ type NeighborEntry struct { } // NeighborState defines the state of a NeighborEntry within the Neighbor -// Unreachability Detection state machine, as per RFC 4861 section 7.3.2. +// Unreachability Detection state machine, as per RFC 4861 section 7.3.2 and +// RFC 7048. type NeighborState uint8 const ( @@ -61,15 +62,33 @@ const ( Delay // Probe means a reachability confirmation is actively being sought by // periodically retransmitting reachability probes until a reachability - // confirmation is received, or until the max amount of probes has been sent. + // confirmation is received, or until the maximum number of probes has been + // sent. Probe // Static describes entries that have been explicitly added by the user. They // do not expire and are not deleted until explicitly removed. Static - // Failed means recent attempts of reachability have returned inconclusive. + // Failed is deprecated and should no longer be used. + // + // TODO(gvisor.dev/issue/4667): Remove this once all references to Failed + // are removed from Fuchsia. Failed + // Unreachable means reachability confirmation failed; the maximum number of + // reachability probes has been sent and no replies have been received. + // + // TODO(gvisor.dev/issue/5472): Add the following sentence when we implement + // RFC 7048: "Packets continue to be sent to the neighbor while + // re-attempting to resolve the address." + Unreachable ) +type timer struct { + // done indicates to the timer that the timer was stopped. + done *bool + + timer tcpip.Timer +} + // neighborEntry implements a neighbor entry's individual node behavior, as per // RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in // parallel with the sending of packets to a neighbor, necessitating the @@ -82,20 +101,22 @@ type neighborEntry struct { // nudState points to the Neighbor Unreachability Detection configuration. nudState *NUDState - // mu protects the fields below. - mu sync.RWMutex + mu struct { + sync.RWMutex + + neigh NeighborEntry - neigh NeighborEntry + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. + done chan struct{} - // done is closed when address resolution is complete. It is nil iff s is - // incomplete and resolution is not yet in progress. - done chan struct{} + // onResolve is called with the result of address resolution. + onResolve []func(LinkResolutionResult) - // onResolve is called with the result of address resolution. - onResolve []func(LinkResolutionResult) + isRouter bool - isRouter bool - job *tcpip.Job + timer timer + } } // newNeighborEntry creates a neighbor cache entry starting at the default @@ -103,14 +124,18 @@ type neighborEntry struct { // `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created // neighborEntry. func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState *NUDState) *neighborEntry { - return &neighborEntry{ + n := &neighborEntry{ cache: cache, nudState: nudState, - neigh: NeighborEntry{ - Addr: remoteAddr, - State: Unknown, - }, } + n.mu.Lock() + n.mu.neigh = NeighborEntry{ + Addr: remoteAddr, + State: Unknown, + } + n.mu.Unlock() + return n + } // newStaticNeighborEntry creates a neighbor cache entry starting at the @@ -123,14 +148,14 @@ func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr t State: Static, UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(), } - if nudDisp := cache.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(cache.nic.id, entry) - } - return &neighborEntry{ + n := &neighborEntry{ cache: cache, nudState: state, - neigh: entry, } + n.mu.Lock() + n.mu.neigh = entry + n.mu.Unlock() + return n } // notifyCompletionLocked notifies those waiting for address resolution, with @@ -138,14 +163,14 @@ func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr t // // Precondition: e.mu MUST be locked. func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { - res := LinkResolutionResult{LinkAddress: e.neigh.LinkAddr, Success: succeeded} - for _, callback := range e.onResolve { + res := LinkResolutionResult{LinkAddress: e.mu.neigh.LinkAddr, Success: succeeded} + for _, callback := range e.mu.onResolve { callback(res) } - e.onResolve = nil - if ch := e.done; ch != nil { + e.mu.onResolve = nil + if ch := e.mu.done; ch != nil { close(ch) - e.done = nil + e.mu.done = nil // Dequeue the pending packets in a new goroutine to not hold up the current // goroutine as writing packets may be a costly operation. // @@ -153,7 +178,7 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // is resolved (which ends up obtaining the entry's lock) while holding the // link resolution queue's lock. Dequeuing packets in a new goroutine avoids // a lock ordering violation. - go e.cache.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) + go e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, succeeded) } } @@ -163,7 +188,7 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchAddEventLocked() { if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(e.cache.nic.id, e.neigh) + nudDisp.OnNeighborAdded(e.cache.nic.id, e.mu.neigh) } } @@ -173,7 +198,7 @@ func (e *neighborEntry) dispatchAddEventLocked() { // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchChangeEventLocked() { if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborChanged(e.cache.nic.id, e.neigh) + nudDisp.OnNeighborChanged(e.cache.nic.id, e.mu.neigh) } } @@ -183,17 +208,20 @@ func (e *neighborEntry) dispatchChangeEventLocked() { // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchRemoveEventLocked() { if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborRemoved(e.cache.nic.id, e.neigh) + nudDisp.OnNeighborRemoved(e.cache.nic.id, e.mu.neigh) } } -// cancelJobLocked cancels the currently scheduled action, if there is one. +// cancelTimerLocked cancels the currently scheduled action, if there is one. // Entries in Unknown, Stale, or Static state do not have a scheduled action. // // Precondition: e.mu MUST be locked. -func (e *neighborEntry) cancelJobLocked() { - if job := e.job; job != nil { - job.Cancel() +func (e *neighborEntry) cancelTimerLocked() { + if e.mu.timer.timer != nil { + e.mu.timer.timer.Stop() + *e.mu.timer.done = true + + e.mu.timer = timer{} } } @@ -201,9 +229,9 @@ func (e *neighborEntry) cancelJobLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) removeLocked() { - e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() e.dispatchRemoveEventLocked() - e.cancelJobLocked() + e.cancelTimerLocked() e.notifyCompletionLocked(false /* succeeded */) } @@ -213,61 +241,98 @@ func (e *neighborEntry) removeLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) setStateLocked(next NeighborState) { - e.cancelJobLocked() + e.cancelTimerLocked() - prev := e.neigh.State - e.neigh.State = next - e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + prev := e.mu.neigh.State + e.mu.neigh.State = next + e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() config := e.nudState.Config() switch next { case Incomplete: - panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev)) + panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.mu.neigh, prev)) case Reachable: - e.job = e.cache.nic.stack.newJob(&e.mu, func() { - e.setStateLocked(Stale) - e.dispatchChangeEventLocked() - }) - e.job.Schedule(e.nudState.ReachableTime()) + // Protected by e.mu. + done := false + + e.mu.timer = timer{ + done: &done, + timer: e.cache.nic.stack.Clock().AfterFunc(e.nudState.ReachableTime(), func() { + e.mu.Lock() + defer e.mu.Unlock() + + if done { + // The timer was stopped because the entry changed state. + return + } + + e.setStateLocked(Stale) + e.dispatchChangeEventLocked() + }), + } case Delay: - e.job = e.cache.nic.stack.newJob(&e.mu, func() { - e.setStateLocked(Probe) - e.dispatchChangeEventLocked() - }) - e.job.Schedule(config.DelayFirstProbeTime) + // Protected by e.mu. + done := false + + e.mu.timer = timer{ + done: &done, + timer: e.cache.nic.stack.Clock().AfterFunc(config.DelayFirstProbeTime, func() { + e.mu.Lock() + defer e.mu.Unlock() + + if done { + // The timer was stopped because the entry changed state. + return + } - case Probe: - var retryCounter uint32 - var sendUnicastProbe func() - - sendUnicastProbe = func() { - if retryCounter == config.MaxUnicastProbes { - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } + e.setStateLocked(Probe) + e.dispatchChangeEventLocked() + }), + } - if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil { - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } + case Probe: + // Protected by e.mu. + done := false - retryCounter++ - e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) - e.job.Schedule(config.RetransmitTimer) - } + remaining := config.MaxUnicastProbes + addr := e.mu.neigh.Addr + linkAddr := e.mu.neigh.LinkAddr // Send a probe in another gorountine to free this thread of execution - // for finishing the state transition. This is necessary to avoid - // deadlock where sending and processing probes are done synchronously, - // such as loopback and integration tests. - e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) - e.job.Schedule(immediateDuration) + // for finishing the state transition. This is necessary to escape the + // currently held lock so we can send the probe message without holding + // a shared lock. + e.mu.timer = timer{ + done: &done, + timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + var err tcpip.Error + timedoutResolution := remaining == 0 + if !timedoutResolution { + err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr) + } + + e.mu.Lock() + defer e.mu.Unlock() + + if done { + // The timer was stopped because the entry changed state. + return + } - case Failed: + if timedoutResolution || err != nil { + e.setStateLocked(Unreachable) + e.dispatchChangeEventLocked() + return + } + + remaining-- + e.mu.timer.timer.Reset(config.RetransmitTimer) + }), + } + + case Unreachable: e.notifyCompletionLocked(false /* succeeded */) case Unknown, Stale, Static: @@ -285,76 +350,66 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { - switch e.neigh.State { - case Failed: - e.cache.nic.stats.Neighbor.FailedEntryLookups.Increment() + switch e.mu.neigh.State { + case Unknown, Unreachable: + prev := e.mu.neigh.State + e.mu.neigh.State = Incomplete + e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + + switch prev { + case Unknown: + e.dispatchAddEventLocked() + case Unreachable: + e.dispatchChangeEventLocked() + e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment() + } - fallthrough - case Unknown: - e.neigh.State = Incomplete - e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + config := e.nudState.Config() - e.dispatchAddEventLocked() + // Protected by e.mu. + done := false - config := e.nudState.Config() + remaining := config.MaxMulticastProbes + addr := e.mu.neigh.Addr - var retryCounter uint32 - var sendMulticastProbe func() - - sendMulticastProbe = func() { - if retryCounter == config.MaxMulticastProbes { - // "If no Neighbor Advertisement is received after - // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. - // The sender MUST return ICMP destination unreachable indications with - // code 3 (Address Unreachable) for each packet queued awaiting address - // resolution." - RFC 4861 section 7.2.2 - // - // There is no need to send an ICMP destination unreachable indication - // since the failure to resolve the address is expected to only occur - // on this node. Thus, redirecting traffic is currently not supported. - // - // "If the error occurs on a node other than the node originating the - // packet, an ICMP error message is generated. If the error occurs on - // the originating node, an implementation is not required to actually - // create and send an ICMP error packet to the source, as long as the - // upper-layer sender is notified through an appropriate mechanism - // (e.g. return value from a procedure call). Note, however, that an - // implementation may find it convenient in some cases to return errors - // to the sender by taking the offending packet, generating an ICMP - // error message, and then delivering it (locally) through the generic - // error-handling routines." - RFC 4861 section 2.1 - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } + // Send a probe in another gorountine to free this thread of execution + // for finishing the state transition. This is necessary to escape the + // currently held lock so we can send the probe message without holding + // a shared lock. + e.mu.timer = timer{ + done: &done, + timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + var err tcpip.Error + timedoutResolution := remaining == 0 + if !timedoutResolution { + // As per RFC 4861 section 7.2.2: + // + // If the source address of the packet prompting the solicitation is + // the same as one of the addresses assigned to the outgoing interface, + // that address SHOULD be placed in the IP Source Address of the + // outgoing solicitation. + // + err = e.cache.linkRes.LinkAddressRequest(addr, localAddr, "" /* linkAddr */) + } - // As per RFC 4861 section 7.2.2: - // - // If the source address of the packet prompting the solicitation is the - // same as one of the addresses assigned to the outgoing interface, that - // address SHOULD be placed in the IP Source Address of the outgoing - // solicitation. - // - if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil { - // There is no need to log the error here; the NUD implementation may - // assume a working link. A valid link should be the responsibility of - // the NIC/stack.LinkEndpoint. - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } + e.mu.Lock() + defer e.mu.Unlock() - retryCounter++ - e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) - e.job.Schedule(config.RetransmitTimer) - } + if done { + // The timer was stopped because the entry changed state. + return + } - // Send a probe in another gorountine to free this thread of execution - // for finishing the state transition. This is necessary to avoid - // deadlock where sending and processing probes are done synchronously, - // such as loopback and integration tests. - e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) - e.job.Schedule(immediateDuration) + if timedoutResolution || err != nil { + e.setStateLocked(Unreachable) + e.dispatchChangeEventLocked() + return + } + + remaining-- + e.mu.timer.timer.Reset(config.RetransmitTimer) + }), + } case Stale: e.setStateLocked(Delay) @@ -363,7 +418,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { case Incomplete, Reachable, Delay, Probe, Static: // Do nothing default: - panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State)) } } @@ -378,9 +433,9 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These // checks MUST be done by the NetworkEndpoint. - switch e.neigh.State { - case Unknown, Failed: - e.neigh.LinkAddr = remoteLinkAddr + switch e.mu.neigh.State { + case Unknown: + e.mu.neigh.LinkAddr = remoteLinkAddr e.setStateLocked(Stale) e.dispatchAddEventLocked() @@ -390,29 +445,36 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // cached address should be replaced by the received address, and the // entry's reachability state MUST be set to STALE." // - RFC 4861 section 7.2.3 - e.neigh.LinkAddr = remoteLinkAddr + e.mu.neigh.LinkAddr = remoteLinkAddr e.setStateLocked(Stale) e.notifyCompletionLocked(true /* succeeded */) e.dispatchChangeEventLocked() case Reachable, Delay, Probe: - if e.neigh.LinkAddr != remoteLinkAddr { - e.neigh.LinkAddr = remoteLinkAddr + if e.mu.neigh.LinkAddr != remoteLinkAddr { + e.mu.neigh.LinkAddr = remoteLinkAddr e.setStateLocked(Stale) e.dispatchChangeEventLocked() } case Stale: - if e.neigh.LinkAddr != remoteLinkAddr { - e.neigh.LinkAddr = remoteLinkAddr + if e.mu.neigh.LinkAddr != remoteLinkAddr { + e.mu.neigh.LinkAddr = remoteLinkAddr e.dispatchChangeEventLocked() } + case Unreachable: + // TODO(gvisor.dev/issue/5472): Do not change the entry if the link + // address is the same, as per RFC 7048. + e.mu.neigh.LinkAddr = remoteLinkAddr + e.setStateLocked(Stale) + e.dispatchChangeEventLocked() + case Static: // Do nothing default: - panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State)) } } @@ -430,7 +492,7 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { - switch e.neigh.State { + switch e.mu.neigh.State { case Incomplete: if len(linkAddr) == 0 { // "If the link layer has addresses and no Target Link-Layer Address @@ -439,35 +501,35 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla break } - e.neigh.LinkAddr = linkAddr + e.mu.neigh.LinkAddr = linkAddr if flags.Solicited { e.setStateLocked(Reachable) } else { e.setStateLocked(Stale) } e.dispatchChangeEventLocked() - e.isRouter = flags.IsRouter + e.mu.isRouter = flags.IsRouter e.notifyCompletionLocked(true /* succeeded */) // "Note that the Override flag is ignored if the entry is in the // INCOMPLETE state." - RFC 4861 section 7.2.5 case Reachable, Stale, Delay, Probe: - isLinkAddrDifferent := len(linkAddr) != 0 && e.neigh.LinkAddr != linkAddr + isLinkAddrDifferent := len(linkAddr) != 0 && e.mu.neigh.LinkAddr != linkAddr if isLinkAddrDifferent { if !flags.Override { - if e.neigh.State == Reachable { + if e.mu.neigh.State == Reachable { e.setStateLocked(Stale) e.dispatchChangeEventLocked() } break } - e.neigh.LinkAddr = linkAddr + e.mu.neigh.LinkAddr = linkAddr if !flags.Solicited { - if e.neigh.State != Stale { + if e.mu.neigh.State != Stale { e.setStateLocked(Stale) e.dispatchChangeEventLocked() } else { @@ -479,7 +541,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } if flags.Solicited && (flags.Override || !isLinkAddrDifferent) { - wasReachable := e.neigh.State == Reachable + wasReachable := e.mu.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) e.notifyCompletionLocked(true /* succeeded */) @@ -488,7 +550,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } } - if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) { + if e.mu.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.mu.neigh.Addr) { // "In those cases where the IsRouter flag changes from TRUE to FALSE as // a result of this update, the node MUST remove that router from the // Default Router List and update the Destination Cache entries for all @@ -505,16 +567,16 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } if ndpEP, ok := ep.(NDPEndpoint); ok { - ndpEP.InvalidateDefaultRouter(e.neigh.Addr) + ndpEP.InvalidateDefaultRouter(e.mu.neigh.Addr) } } - e.isRouter = flags.IsRouter + e.mu.isRouter = flags.IsRouter - case Unknown, Failed, Static: + case Unknown, Unreachable, Static: // Do nothing default: - panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State)) } } @@ -523,19 +585,19 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // // Precondition: e.mu MUST be locked. func (e *neighborEntry) handleUpperLevelConfirmationLocked() { - switch e.neigh.State { + switch e.mu.neigh.State { case Reachable, Stale, Delay, Probe: - wasReachable := e.neigh.State == Reachable + wasReachable := e.mu.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) if !wasReachable { e.dispatchChangeEventLocked() } - case Unknown, Incomplete, Failed, Static: + case Unknown, Incomplete, Unreachable, Static: // Do nothing default: - panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State)) } } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 57cfbdb8b..47a9e2448 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -70,38 +70,39 @@ func eventDiffOptsWithSort() []cmp.Option { } // The following unit tests exercise every state transition and verify its -// behavior with RFC 4681. +// behavior with RFC 4681 and RFC 7048. // -// | From | To | Cause | Update | Action | Event | -// | ========== | ========== | ========================================== | ======== | ===========| ======= | -// | Unknown | Unknown | Confirmation w/ unknown address | | | Added | -// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added | -// | Unknown | Stale | Probe | | | Added | -// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed | -// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed | -// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed | -// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed | -// | Incomplete | Failed | Max probes sent without reply | | Notify | Removed | -// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | | -// | Reachable | Stale | Reachable timer expired | | | Changed | -// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed | -// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Stale | Stale | Override confirmation | LinkAddr | | Changed | -// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed | -// | Stale | Delay | Packet sent | | | Changed | -// | Delay | Reachable | Upper-layer confirmation | | | Changed | -// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Delay | Stale | Probe or confirmation w/ different address | | | Changed | -// | Delay | Probe | Delay timer expired | | Send probe | Changed | -// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed | -// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Probe | Stale | Probe or confirmation w/ different address | | | Changed | -// | Probe | Probe | Retransmit timer expired | | | Changed | -// | Probe | Failed | Max probes sent without reply | | Notify | Removed | -// | Failed | Incomplete | Packet queued | | Send probe | Added | +// | From | To | Cause | Update | Action | Event | +// | =========== | =========== | ========================================== | ======== | ===========| ======= | +// | Unknown | Unknown | Confirmation w/ unknown address | | | Added | +// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added | +// | Unknown | Stale | Probe | | | Added | +// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed | +// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed | +// | Incomplete | Unreachable | Max probes sent without reply | | Notify | Changed | +// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | | +// | Reachable | Stale | Reachable timer expired | | | Changed | +// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed | +// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Stale | Stale | Override confirmation | LinkAddr | | Changed | +// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed | +// | Stale | Delay | Packet sent | | | Changed | +// | Delay | Reachable | Upper-layer confirmation | | | Changed | +// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Delay | Stale | Probe or confirmation w/ different address | | | Changed | +// | Delay | Probe | Delay timer expired | | Send probe | Changed | +// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed | +// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Probe | Stale | Probe or confirmation w/ different address | | | Changed | +// | Probe | Probe | Retransmit timer expired | | | Changed | +// | Probe | Unreachable | Max probes sent without reply | | Notify | Changed | +// | Unreachable | Incomplete | Packet queued | | Send probe | Changed | +// | Unreachable | Stale | Probe w/ different address | LinkAddr | | Changed | type testEntryEventType uint8 @@ -220,13 +221,15 @@ func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) { clock := faketime.NewManualClock() disp := testNUDDispatcher{} - nic := NIC{ + nic := nic{ LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint id: entryTestNICID, stack: &Stack{ - clock: clock, - nudDisp: &disp, + clock: clock, + nudDisp: &disp, + nudConfigs: c, + randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), }, stats: makeNICStats(), } @@ -235,23 +238,18 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e header.IPv6ProtocolNumber: netEP, } - rng := rand.New(rand.NewSource(time.Now().UnixNano())) - nudState := NewNUDState(c, rng) var linkRes entryTestLinkResolver // Stub out the neighbor cache to verify deletion from the cache. - neigh := &neighborCache{ - nic: &nic, - state: nudState, - linkRes: &linkRes, - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), - } - l := linkResolver{ - resolver: &linkRes, - neighborTable: neigh, - } - entry := newNeighborEntry(neigh, entryTestAddr1 /* remoteAddr */, nudState) - neigh.cache[entryTestAddr1] = entry - nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]linkResolver{ + l := &linkResolver{ + resolver: &linkRes, + } + l.neigh.init(&nic, &linkRes) + + entry := newNeighborEntry(&l.neigh, entryTestAddr1 /* remoteAddr */, l.neigh.state) + l.neigh.mu.Lock() + l.neigh.mu.cache[entryTestAddr1] = entry + l.neigh.mu.Unlock() + nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]*linkResolver{ header.IPv6ProtocolNumber: l, } @@ -265,8 +263,8 @@ func TestEntryInitiallyUnknown(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - if e.neigh.State != Unknown { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown) + if e.mu.neigh.State != Unknown { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) } e.mu.Unlock() @@ -298,8 +296,8 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Unknown { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown) + if e.mu.neigh.State != Unknown { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) } e.mu.Unlock() @@ -327,8 +325,8 @@ func TestEntryUnknownToIncomplete(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -374,8 +372,8 @@ func TestEntryUnknownToStale(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -413,10 +411,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } - updatedAtNanos := e.neigh.UpdatedAtNanos + updatedAtNanos := e.mu.neigh.UpdatedAtNanos e.mu.Unlock() clock.Advance(c.RetransmitTimer) @@ -443,15 +441,15 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.UpdatedAtNanos, updatedAtNanos; got != want { - t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want) + if got, want := e.mu.neigh.UpdatedAtNanos, updatedAtNanos; got != want { + t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want) } e.mu.Unlock() clock.Advance(c.RetransmitTimer) // UpdatedAt should change after failing address resolution. Timing out after - // sending the last probe transitions the entry to Failed. + // sending the last probe transitions the entry to Unreachable. { wantProbes := []entryTestProbeInfo{ { @@ -481,12 +479,12 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, }, { - EventType: entryTestRemoved, + EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + State: Unreachable, }, }, } @@ -497,8 +495,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, notWant := e.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant { - t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got) + if got, notWant := e.mu.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant { + t.Errorf("expected e.mu.neigh.UpdatedAt to change, got = %q", got) } e.mu.Unlock() } @@ -509,8 +507,8 @@ func TestEntryIncompleteToReachable(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -535,8 +533,8 @@ func TestEntryIncompleteToReachable(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.mu.Unlock() @@ -573,8 +571,8 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -599,11 +597,11 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { Override: false, IsRouter: true, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if !e.isRouter { - t.Errorf("got e.isRouter = %t, want = true", e.isRouter) + if !e.mu.isRouter { + t.Errorf("got e.mu.isRouter = %t, want = true", e.mu.isRouter) } e.mu.Unlock() @@ -640,8 +638,8 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -666,8 +664,8 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -704,8 +702,8 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -726,8 +724,8 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -758,15 +756,15 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryIncompleteToFailed(t *testing.T) { +func TestEntryIncompleteToUnreachable(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -810,12 +808,12 @@ func TestEntryIncompleteToFailed(t *testing.T) { }, }, { - EventType: entryTestRemoved, + EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + State: Unreachable, }, }, } @@ -826,8 +824,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if e.neigh.State != Failed { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) + if e.mu.neigh.State != Unreachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) } e.mu.Unlock() } @@ -870,11 +868,11 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { Override: false, IsRouter: true, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if got, want := e.isRouter, true; got != want { - t.Errorf("got e.isRouter = %t, want = %t", got, want) + if got, want := e.mu.isRouter, true; got != want { + t.Errorf("got e.mu.isRouter = %t, want = %t", got, want) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ @@ -882,11 +880,11 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { Override: false, IsRouter: false, }) - if got, want := e.isRouter, false; got != want { - t.Errorf("got e.isRouter = %t, want = %t", got, want) + if got, want := e.mu.isRouter, false; got != want { + t.Errorf("got e.mu.isRouter = %t, want = %t", got, want) } - if ipv6EP.invalidatedRtr != e.neigh.Addr { - t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.neigh.Addr) + if ipv6EP.invalidatedRtr != e.mu.neigh.Addr { + t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.mu.neigh.Addr) } e.mu.Unlock() @@ -917,8 +915,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.mu.Unlock() } @@ -952,15 +950,15 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.handleProbeLocked(entryTestLinkAddr1) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if e.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) } e.mu.Unlock() @@ -1025,8 +1023,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.mu.Unlock() @@ -1068,8 +1066,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() } @@ -1103,12 +1101,12 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.handleProbeLocked(entryTestLinkAddr2) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -1177,16 +1175,16 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -1255,16 +1253,16 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -1333,15 +1331,15 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.handleProbeLocked(entryTestLinkAddr1) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } - if e.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) } e.mu.Unlock() @@ -1401,19 +1399,19 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if e.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) } e.mu.Unlock() @@ -1482,19 +1480,19 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if e.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) } e.mu.Unlock() @@ -1563,19 +1561,19 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } - if e.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) } e.mu.Unlock() @@ -1644,15 +1642,15 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.handleProbeLocked(entryTestLinkAddr2) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } - if e.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) } e.mu.Unlock() @@ -1721,12 +1719,12 @@ func TestEntryStaleToDelay(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -1801,12 +1799,12 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.handleUpperLevelConfirmationLocked() - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.mu.Unlock() @@ -1901,19 +1899,19 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if e.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) } e.mu.Unlock() @@ -2008,19 +2006,19 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if e.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) } e.mu.Unlock() @@ -2109,19 +2107,19 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } - if e.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) } e.mu.Unlock() @@ -2191,12 +2189,12 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.handleProbeLocked(entryTestLinkAddr2) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -2275,16 +2273,16 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -2366,8 +2364,8 @@ func TestEntryDelayToProbe(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.mu.Unlock() @@ -2432,8 +2430,8 @@ func TestEntryDelayToProbe(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.mu.Unlock() } @@ -2490,12 +2488,12 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleProbeLocked(entryTestLinkAddr2) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -2605,16 +2603,16 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -2725,19 +2723,19 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want) } e.mu.Unlock() @@ -2821,19 +2819,19 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want) } e.mu.Unlock() @@ -2949,19 +2947,19 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want) } e.mu.Unlock() @@ -3086,16 +3084,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.mu.Unlock() @@ -3220,16 +3218,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.mu.Unlock() @@ -3297,7 +3295,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing nudDisp.mu.Unlock() } -func TestEntryProbeToFailed(t *testing.T) { +func TestEntryProbeToUnreachable(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 c.MaxUnicastProbes = 3 @@ -3352,17 +3350,17 @@ func TestEntryProbeToFailed(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.mu.Unlock() } - // Wait for the last probe to expire, causing a transition to Failed. + // Wait for the last probe to expire, causing a transition to Unreachable. clock.Advance(c.RetransmitTimer) e.mu.Lock() - if e.neigh.State != Failed { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) + if e.mu.neigh.State != Unreachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) } e.mu.Unlock() @@ -3404,12 +3402,12 @@ func TestEntryProbeToFailed(t *testing.T) { }, }, { - EventType: entryTestRemoved, + EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, LinkAddr: entryTestLinkAddr1, - State: Probe, + State: Unreachable, }, }, } @@ -3420,7 +3418,7 @@ func TestEntryProbeToFailed(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryFailedToIncomplete(t *testing.T) { +func TestEntryUnreachableToIncomplete(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 e, nudDisp, linkRes, clock := entryTestSetup(c) @@ -3429,8 +3427,8 @@ func TestEntryFailedToIncomplete(t *testing.T) { // their expected state. e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -3464,15 +3462,15 @@ func TestEntryFailedToIncomplete(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Failed { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) + if e.mu.neigh.State != Unreachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) } e.mu.Unlock() e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -3487,7 +3485,16 @@ func TestEntryFailedToIncomplete(t *testing.T) { }, }, { - EventType: entryTestRemoved, + EventType: entryTestChanged, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + }, + }, + { + EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, @@ -3495,6 +3502,72 @@ func TestEntryFailedToIncomplete(t *testing.T) { State: Incomplete, }, }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryUnreachableToStale(t *testing.T) { + wantProbes := []entryTestProbeInfo{ + // The Incomplete-to-Incomplete state transition is tested here by + // verifying that 3 reachability probes were sent. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = uint32(len(wantProbes)) + e, nudDisp, linkRes, clock := entryTestSetup(c) + + // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in + // their expected state. + e.mu.Lock() + e.handlePacketQueuedLocked(entryTestAddr2) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) + } + e.mu.Unlock() + + waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) + clock.Advance(waitFor) + + linkRes.mu.Lock() + diff := cmp.Diff(wantProbes, linkRes.probes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) + } + + e.mu.Lock() + if e.mu.neigh.State != Unreachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) + } + e.mu.Unlock() + + e.mu.Lock() + e.handleProbeLocked(entryTestLinkAddr2) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, @@ -3504,6 +3577,24 @@ func TestEntryFailedToIncomplete(t *testing.T) { State: Incomplete, }, }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + }, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + }, } nudDisp.mu.Lock() if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { diff --git a/pkg/tcpip/stack/neighborstate_string.go b/pkg/tcpip/stack/neighborstate_string.go index aa7311ec6..765df4d7a 100644 --- a/pkg/tcpip/stack/neighborstate_string.go +++ b/pkg/tcpip/stack/neighborstate_string.go @@ -1,4 +1,4 @@ -// Copyright 2020 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -30,11 +30,12 @@ func _() { _ = x[Probe-5] _ = x[Static-6] _ = x[Failed-7] + _ = x[Unreachable-8] } -const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailed" +const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailedUnreachable" -var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53} +var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53, 64} func (i NeighborState) String() string { if i >= NeighborState(len(_NeighborState_index)-1) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 41a489047..f66db16a7 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -24,40 +24,26 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -type neighborTable interface { - neighbors() ([]NeighborEntry, tcpip.Error) - addStaticEntry(tcpip.Address, tcpip.LinkAddress) - get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) - remove(tcpip.Address) tcpip.Error - removeAll() tcpip.Error - - handleProbe(tcpip.Address, tcpip.LinkAddress) - handleConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) - handleUpperLevelConfirmation(tcpip.Address) - - nudConfig() (NUDConfigurations, tcpip.Error) - setNUDConfig(NUDConfigurations) tcpip.Error -} - -var _ NetworkInterface = (*NIC)(nil) - type linkResolver struct { resolver LinkAddressResolver - neighborTable neighborTable + neigh neighborCache } func (l *linkResolver) getNeighborLinkAddress(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - return l.neighborTable.get(addr, localAddr, onResolve) + entry, ch, err := l.neigh.entry(addr, localAddr, onResolve) + return entry.LinkAddr, ch, err } func (l *linkResolver) confirmReachable(addr tcpip.Address) { - l.neighborTable.handleUpperLevelConfirmation(addr) + l.neigh.handleUpperLevelConfirmation(addr) } -// NIC represents a "network interface card" to which the networking stack is +var _ NetworkInterface = (*nic)(nil) + +// nic represents a "network interface card" to which the networking stack is // attached. -type NIC struct { +type nic struct { LinkEndpoint stack *Stack @@ -69,8 +55,9 @@ type NIC struct { // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. - networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint - linkAddrResolvers map[tcpip.NetworkProtocolNumber]linkResolver + networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + linkAddrResolvers map[tcpip.NetworkProtocolNumber]*linkResolver + duplicateAddressDetectors map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. // @@ -147,7 +134,7 @@ func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) { } // newNIC returns a new NIC using the default NDP configurations from stack. -func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { +func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *nic { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For // example, make sure that the link address it provides is a valid // unicast ethernet address. @@ -156,16 +143,17 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // observe an MTU of at least 1280 bytes. Ensure that this requirement // of IPv6 is supported on this endpoint's LinkEndpoint. - nic := &NIC{ + nic := &nic{ LinkEndpoint: ep, - stack: stack, - id: id, - name: name, - context: ctx, - stats: makeNICStats(), - networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]linkResolver), + stack: stack, + id: id, + name: name, + context: ctx, + stats: makeNICStats(), + networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver), + duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), } nic.linkResQueue.init(nic) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) @@ -185,26 +173,15 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC if resolutionRequired { if r, ok := netEP.(LinkAddressResolver); ok { - l := linkResolver{ - resolver: r, - } - - if stack.useNeighborCache { - l.neighborTable = &neighborCache{ - nic: nic, - state: NewNUDState(stack.nudConfigs, stack.randomGenerator), - linkRes: r, - - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), - } - } else { - cache := new(linkAddrCache) - cache.init(nic, ageLimit, resolutionTimeout, resolutionAttempts, r) - l.neighborTable = cache - } + l := &linkResolver{resolver: r} + l.neigh.init(nic, r) nic.linkAddrResolvers[r.LinkAddressProtocol()] = l } } + + if d, ok := netEP.(DuplicateAddressDetector); ok { + nic.duplicateAddressDetectors[d.DuplicateAddressProtocol()] = d + } } nic.LinkEndpoint.Attach(nic) @@ -212,19 +189,19 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC return nic } -func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint { +func (n *nic) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint { return n.networkEndpoints[proto] } // Enabled implements NetworkInterface. -func (n *NIC) Enabled() bool { +func (n *nic) Enabled() bool { return atomic.LoadUint32(&n.enabled) == 1 } // setEnabled sets the enabled status for the NIC. // // Returns true if the enabled status was updated. -func (n *NIC) setEnabled(v bool) bool { +func (n *nic) setEnabled(v bool) bool { if v { return atomic.SwapUint32(&n.enabled, 1) == 0 } @@ -234,7 +211,7 @@ func (n *NIC) setEnabled(v bool) bool { // disable disables n. // // It undoes the work done by enable. -func (n *NIC) disable() { +func (n *nic) disable() { n.mu.Lock() n.disableLocked() n.mu.Unlock() @@ -245,7 +222,7 @@ func (n *NIC) disable() { // It undoes the work done by enable. // // n MUST be locked. -func (n *NIC) disableLocked() { +func (n *nic) disableLocked() { if !n.Enabled() { return } @@ -283,7 +260,7 @@ func (n *NIC) disableLocked() { // address (ff02::1), start DAD for permanent addresses, and start soliciting // routers if the stack is not operating as a router. If the stack is also // configured to auto-generate a link-local address, one will be generated. -func (n *NIC) enable() tcpip.Error { +func (n *nic) enable() tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -303,7 +280,7 @@ func (n *NIC) enable() tcpip.Error { // remove detaches NIC from the link endpoint and releases network endpoint // resources. This guarantees no packets between this NIC and the network // stack. -func (n *NIC) remove() tcpip.Error { +func (n *nic) remove() tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -319,14 +296,14 @@ func (n *NIC) remove() tcpip.Error { } // setPromiscuousMode enables or disables promiscuous mode. -func (n *NIC) setPromiscuousMode(enable bool) { +func (n *nic) setPromiscuousMode(enable bool) { n.mu.Lock() n.mu.promiscuous = enable n.mu.Unlock() } // Promiscuous implements NetworkInterface. -func (n *NIC) Promiscuous() bool { +func (n *nic) Promiscuous() bool { n.mu.RLock() rv := n.mu.promiscuous n.mu.RUnlock() @@ -334,17 +311,17 @@ func (n *NIC) Promiscuous() bool { } // IsLoopback implements NetworkInterface. -func (n *NIC) IsLoopback() bool { +func (n *nic) IsLoopback() bool { return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0 } // WritePacket implements NetworkLinkEndpoint. -func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (n *nic) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt) return err } -func (n *NIC) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { +func (n *nic) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { switch pkt := pkt.(type) { case *PacketBuffer: if err := n.writePacket(r, gso, protocol, pkt); err != nil { @@ -358,7 +335,7 @@ func (n *NIC) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkPro } } -func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { +func (n *nic) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { routeInfo, _, err := r.resolvedFields(nil) switch err.(type) { case nil: @@ -388,14 +365,14 @@ func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProt } // WritePacketToRemote implements NetworkInterface. -func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { var r RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr return n.writePacket(r, gso, protocol, pkt) } -func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (n *nic) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() @@ -412,11 +389,11 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN } // WritePackets implements NetworkLinkEndpoint. -func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (n *nic) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { return n.enqueuePacketBuffer(r, gso, protocol, &pkts) } -func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { +func (n *nic) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { pkt.EgressRoute = r pkt.GSOOptions = gso @@ -435,15 +412,22 @@ func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocol } // setSpoofing enables or disables address spoofing. -func (n *NIC) setSpoofing(enable bool) { +func (n *nic) setSpoofing(enable bool) { n.mu.Lock() n.mu.spoofing = enable n.mu.Unlock() } +// Spoofing implements NetworkInterface. +func (n *nic) Spoofing() bool { + n.mu.RLock() + defer n.mu.RUnlock() + return n.mu.spoofing +} + // primaryAddress returns an address that can be used to communicate with // remoteAddr. -func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { +func (n *nic) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { ep, ok := n.networkEndpoints[protocol] if !ok { return nil @@ -473,11 +457,11 @@ const ( promiscuous ) -func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) AssignableAddressEndpoint { +func (n *nic) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) AssignableAddressEndpoint { return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } -func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { +func (n *nic) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint) if ep != nil { ep.DecRef() @@ -488,7 +472,7 @@ func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres } // findEndpoint finds the endpoint, if any, with the given address. -func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { +func (n *nic) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { return n.getAddressOrCreateTemp(protocol, address, peb, spoofing) } @@ -501,7 +485,7 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // // If the address is the IPv4 broadcast address for an endpoint's network, that // endpoint will be returned. -func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getAddressBehaviour) AssignableAddressEndpoint { +func (n *nic) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getAddressBehaviour) AssignableAddressEndpoint { n.mu.RLock() var spoofingOrPromiscuous bool switch tempRef { @@ -516,7 +500,7 @@ func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, addre // getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean // is passed to indicate whether or not we should generate temporary endpoints. -func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { +func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { ep, ok := n.networkEndpoints[protocol] if !ok { return nil @@ -532,7 +516,7 @@ func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, // addAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). -func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { +func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { return &tcpip.ErrUnknownProtocol{} @@ -553,7 +537,7 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo // allPermanentAddresses returns all permanent addresses associated with // this NIC. -func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { +func (n *nic) allPermanentAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { addressableEndpoint, ok := ep.(AddressableEndpoint) @@ -569,7 +553,7 @@ func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { } // primaryAddresses returns the primary addresses associated with this NIC. -func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { +func (n *nic) primaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { addressableEndpoint, ok := ep.(AddressableEndpoint) @@ -589,7 +573,7 @@ func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { // primaryAddress will return the first non-deprecated address if such an // address exists. If no non-deprecated address exists, the first deprecated // address will be returned. -func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { +func (n *nic) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { ep, ok := n.networkEndpoints[proto] if !ok { return tcpip.AddressWithPrefix{} @@ -604,7 +588,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit } // removeAddress removes an address from n. -func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error { +func (n *nic) removeAddress(addr tcpip.Address) tcpip.Error { for _, ep := range n.networkEndpoints { addressableEndpoint, ok := ep.(AddressableEndpoint) if !ok { @@ -622,7 +606,7 @@ func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error { return &tcpip.ErrBadLocalAddress{} } -func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { +func (n *nic) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { linkRes, ok := n.linkAddrResolvers[protocol] if !ok { return &tcpip.ErrNotSupported{} @@ -637,34 +621,38 @@ func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.Netwo return err } -func (n *NIC) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { +func (n *nic) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - return linkRes.neighborTable.neighbors() + return linkRes.neigh.entries(), nil } return nil, &tcpip.ErrNotSupported{} } -func (n *NIC) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtocolNumber, linkAddress tcpip.LinkAddress) tcpip.Error { +func (n *nic) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtocolNumber, linkAddress tcpip.LinkAddress) tcpip.Error { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - linkRes.neighborTable.addStaticEntry(addr, linkAddress) + linkRes.neigh.addStaticEntry(addr, linkAddress) return nil } return &tcpip.ErrNotSupported{} } -func (n *NIC) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { +func (n *nic) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - return linkRes.neighborTable.remove(addr) + if !linkRes.neigh.removeEntry(addr) { + return &tcpip.ErrBadAddress{} + } + return nil } return &tcpip.ErrNotSupported{} } -func (n *NIC) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error { +func (n *nic) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - return linkRes.neighborTable.removeAll() + linkRes.neigh.clear() + return nil } return &tcpip.ErrNotSupported{} @@ -672,7 +660,7 @@ func (n *NIC) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error { // joinGroup adds a new endpoint for the given multicast address, if none // exists yet. Otherwise it just increments its count. -func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { +func (n *nic) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { // TODO(b/143102137): When implementing MLD, make sure MLD packets are // not sent unless a valid link-local address is available for use on n // as an MLD packet's source address must be a link-local address as @@ -693,7 +681,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address // leaveGroup decrements the count for the given multicast address, and when it // reaches zero removes the endpoint for this address. -func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { +func (n *nic) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { ep, ok := n.networkEndpoints[protocol] if !ok { return &tcpip.ErrNotSupported{} @@ -708,7 +696,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres } // isInGroup returns true if n has joined the multicast group addr. -func (n *NIC) isInGroup(addr tcpip.Address) bool { +func (n *nic) isInGroup(addr tcpip.Address) bool { for _, ep := range n.networkEndpoints { gep, ok := ep.(GroupAddressableEndpoint) if !ok { @@ -729,7 +717,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() enabled := n.Enabled() // If the NIC is not yet enabled, don't receive any packets. @@ -777,41 +765,11 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp anyEPs.forEach(deliverPacketEPs) } - // Parse headers. - netProto := n.stack.NetworkProtocolInstance(protocol) - transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) - if !ok { - // The packet is too small to contain a network header. - n.stack.stats.MalformedRcvdPackets.Increment() - return - } - if hasTransportHdr { - pkt.TransportProtocolNumber = transProtoNum - // Parse the transport header if present. - if state, ok := n.stack.transportProtocols[transProtoNum]; ok { - state.proto.Parse(pkt) - } - } - - if n.stack.handleLocal && !n.IsLoopback() { - src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if r := n.getAddress(protocol, src); r != nil { - r.DecRef() - - // The source address is one of our own, so we never should have gotten a - // packet like this unless handleLocal is false. Loopback also calls this - // function even though the packets didn't come from the physical interface - // so don't drop those. - n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() - return - } - } - networkEndpoint.HandlePacket(pkt) } // DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. -func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() // We do not deliver to protocol specific packet endpoints as on Linux // only ETH_P_ALL endpoints get outbound packets. @@ -831,7 +789,7 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { +func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -912,7 +870,7 @@ func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt } // DeliverTransportError implements TransportDispatcher. -func (n *NIC) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer) { +func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -940,19 +898,19 @@ func (n *NIC) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo } // ID implements NetworkInterface. -func (n *NIC) ID() tcpip.NICID { +func (n *nic) ID() tcpip.NICID { return n.id } // Name implements NetworkInterface. -func (n *NIC) Name() string { +func (n *nic) Name() string { return n.name } // nudConfigs gets the NUD configurations for n. -func (n *NIC) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { +func (n *nic) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - return linkRes.neighborTable.nudConfig() + return linkRes.neigh.config(), nil } return NUDConfigurations{}, &tcpip.ErrNotSupported{} @@ -962,16 +920,17 @@ func (n *NIC) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfiguration // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (n *NIC) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { +func (n *nic) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { c.resetInvalidFields() - return linkRes.neighborTable.setNUDConfig(c) + linkRes.neigh.setConfig(c) + return nil } return &tcpip.ErrNotSupported{} } -func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { +func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -984,7 +943,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa return nil } -func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { +func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { n.mu.Lock() defer n.mu.Unlock() @@ -998,7 +957,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep // isValidForOutgoing returns true if the endpoint can be used to send out a // packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed) unless the NIC is in spoofing mode, or temporary. -func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { +func (n *nic) isValidForOutgoing(ep AssignableAddressEndpoint) bool { n.mu.RLock() spoofing := n.mu.spoofing n.mu.RUnlock() @@ -1006,9 +965,9 @@ func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { } // HandleNeighborProbe implements NetworkInterface. -func (n *NIC) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { +func (n *nic) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { if l, ok := n.linkAddrResolvers[protocol]; ok { - l.neighborTable.handleProbe(addr, linkAddr) + l.neigh.handleProbe(addr, linkAddr) return nil } @@ -1016,11 +975,34 @@ func (n *NIC) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcp } // HandleNeighborConfirmation implements NetworkInterface. -func (n *NIC) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) tcpip.Error { +func (n *nic) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) tcpip.Error { if l, ok := n.linkAddrResolvers[protocol]; ok { - l.neighborTable.handleConfirmation(addr, linkAddr, flags) + l.neigh.handleConfirmation(addr, linkAddr, flags) return nil } return &tcpip.ErrNotSupported{} } + +// CheckLocalAddress implements NetworkInterface. +func (n *nic) CheckLocalAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + if n.Spoofing() { + return true + } + + if addressEndpoint := n.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return true + } + + return false +} + +func (n *nic) checkDuplicateAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, h DADCompletionHandler) (DADCheckAddressDisposition, tcpip.Error) { + d, ok := n.duplicateAddressDetectors[protocol] + if !ok { + return 0, &tcpip.ErrNotSupported{} + } + + return d.CheckDuplicateAddress(addr, h), nil +} diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 9992d6eb4..c0f956e53 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -170,7 +170,7 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. - nic := NIC{ + nic := nic{ stats: makeNICStats(), } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index e9acef6a2..e1253f310 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -101,7 +101,6 @@ func TestNUDFunctions(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, NetworkProtocols: test.netProtoFactory, Clock: clock, }) @@ -206,7 +205,6 @@ func TestDefaultNUDConfigurations(t *testing.T) { // address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -261,7 +259,6 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -318,7 +315,6 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -398,7 +394,6 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -460,7 +455,6 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -512,7 +506,6 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -564,7 +557,6 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -616,7 +608,6 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 1c651e216..dc139ebb2 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -55,7 +55,7 @@ type pendingPacket struct { // // Once link resolution completes successfully, the packets will be written. type packetsPendingLinkResolution struct { - nic *NIC + nic *nic mu struct { sync.Mutex @@ -82,7 +82,7 @@ func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip } } -func (f *packetsPendingLinkResolution) init(nic *NIC) { +func (f *packetsPendingLinkResolution) init(nic *nic) { f.mu.Lock() defer f.mu.Unlock() f.nic = nic diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index d589f798d..43e9e4beb 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -514,8 +515,19 @@ type NetworkInterface interface { Enabled() bool // Promiscuous returns true if the interface is in promiscuous mode. + // + // When in promiscuous mode, the interface should accept all packets. Promiscuous() bool + // Spoofing returns true if the interface is in spoofing mode. + // + // When in spoofing mode, the interface should consider all addresses as + // assigned to it. + Spoofing() bool + + // CheckLocalAddress returns true if the address exists on the interface. + CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool + // WritePacketToRemote writes the packet to the given remote link address. WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error @@ -840,7 +852,97 @@ type InjectableLinkEndpoint interface { InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error } -// A LinkAddressResolver handles link address resolution for a network protocol. +// DADResult is the result of a duplicate address detection process. +type DADResult struct { + // Resolved is true when DAD completed without detecting a duplicate address + // on the link. + // + // Ignored when Err is non-nil. + Resolved bool + + // Err is an error encountered while performing DAD. + Err tcpip.Error +} + +// DADCompletionHandler is a handler for DAD completion. +type DADCompletionHandler func(DADResult) + +// DADCheckAddressDisposition enumerates the possible return values from +// DAD.CheckDuplicateAddress. +type DADCheckAddressDisposition int + +const ( + _ DADCheckAddressDisposition = iota + + // DADDisabled indicates that DAD is disabled. + DADDisabled + + // DADStarting indicates that DAD is starting for an address. + DADStarting + + // DADAlreadyRunning indicates that DAD was already started for an address. + DADAlreadyRunning +) + +const ( + // defaultDupAddrDetectTransmits is the default number of NDP Neighbor + // Solicitation messages to send when doing Duplicate Address Detection + // for a tentative address. + // + // Default = 1 (from RFC 4862 section 5.1) + defaultDupAddrDetectTransmits = 1 +) + +// DADConfigurations holds configurations for duplicate address detection. +type DADConfigurations struct { + // The number of Neighbor Solicitation messages to send when doing + // Duplicate Address Detection for a tentative address. + // + // Note, a value of zero effectively disables DAD. + DupAddrDetectTransmits uint8 + + // The amount of time to wait between sending Neighbor Solicitation + // messages. + // + // Must be greater than or equal to 1ms. + RetransmitTimer time.Duration +} + +// DefaultDADConfigurations returns the default DAD configurations. +func DefaultDADConfigurations() DADConfigurations { + return DADConfigurations{ + DupAddrDetectTransmits: defaultDupAddrDetectTransmits, + RetransmitTimer: defaultRetransmitTimer, + } +} + +// Validate modifies the configuration with valid values. If invalid values are +// present in the configurations, the corresponding default values are used +// instead. +func (c *DADConfigurations) Validate() { + if c.RetransmitTimer < minimumRetransmitTimer { + c.RetransmitTimer = defaultRetransmitTimer + } +} + +// DuplicateAddressDetector handles checking if an address is already assigned +// to some neighboring node on the link. +type DuplicateAddressDetector interface { + // CheckDuplicateAddress checks if an address is assigned to a neighbor. + // + // If DAD is already being performed for the address, the handler will be + // called with the result of the original DAD request. + CheckDuplicateAddress(tcpip.Address, DADCompletionHandler) DADCheckAddressDisposition + + // SetDADConfiguations sets the configurations for DAD. + SetDADConfigurations(c DADConfigurations) + + // DuplicateAddressProtocol returns the network protocol the receiver can + // perform duplicate address detection for. + DuplicateAddressProtocol() tcpip.NetworkProtocolNumber +} + +// LinkAddressResolver handles link address resolution for a network protocol. type LinkAddressResolver interface { // LinkAddressRequest sends a request for the link address of the target // address. The request is broadcasted on the local network if a remote link diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index bab55ce49..e946f9fe3 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -35,7 +35,7 @@ type Route struct { // localAddressNIC is the interface the address is associated with. // TODO(gvisor.dev/issue/4548): Remove this field once we can query the // address's assigned status without the NIC. - localAddressNIC *NIC + localAddressNIC *nic mu struct { sync.RWMutex @@ -49,11 +49,11 @@ type Route struct { } // outgoingNIC is the interface this route uses to write packets. - outgoingNIC *NIC + outgoingNIC *nic // linkRes is set if link address resolution is enabled for this protocol on // the route's NIC. - linkRes linkResolver + linkRes *linkResolver } type routeInfo struct { @@ -108,7 +108,7 @@ func (r *Route) fieldsLocked() RouteInfo { // ownership of the provided local address. // // Returns an empty route if validation fails. -func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route { +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *nic, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route { if len(localAddr) == 0 { localAddr = addressEndpoint.AddressWithPrefix().Address } @@ -140,7 +140,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp // makeRoute initializes a new route. It takes ownership of the provided // AssignableAddressEndpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *nic, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route { if localAddressNIC.stack != outgoingNIC.stack { panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) } @@ -184,7 +184,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA return r } - if r.linkRes.resolver == nil { + if r.linkRes == nil { return r } @@ -206,7 +206,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA return r } -func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route { +func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *nic, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route { r := &Route{ routeInfo: routeInfo{ NetProto: netProto, @@ -230,7 +230,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr // provided AssignableAddressEndpoint. // // A local route is a route to a destination that is local to the stack. -func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route { +func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *nic, localAddressEndpoint AssignableAddressEndpoint) *Route { loop := PacketLoop // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the // link endpoint level. We can remove this check once loopback interfaces @@ -400,7 +400,7 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isResolutionRequiredRLocked() bool { - return len(r.mu.remoteLinkAddress) == 0 && r.linkRes.resolver != nil && r.isValidForOutgoingRLocked() && !r.local() + return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local() } func (r *Route) isValidForOutgoing() bool { @@ -528,7 +528,7 @@ func (r *Route) IsOutboundBroadcast() bool { // "Reachable" is defined as having full-duplex communication between the // local and remote ends of the route. func (r *Route) ConfirmReachable() { - if r.linkRes.resolver != nil { + if r.linkRes != nil { r.linkRes.confirmReachable(r.nextHop()) } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a51d758d0..674c9a1ff 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -395,7 +395,7 @@ type Stack struct { } mu sync.RWMutex - nics map[tcpip.NICID]*NIC + nics map[tcpip.NICID]*nic // cleanupEndpointsMu protects cleanupEndpoints. cleanupEndpointsMu sync.Mutex @@ -434,12 +434,6 @@ type Stack struct { // nudConfigs is the default NUD configurations used by interfaces. nudConfigs NUDConfigurations - // useNeighborCache indicates whether ARP and NDP packets should be handled - // by the NIC's neighborCache instead of linkAddrCache. - // - // TODO(gvisor.dev/issue/4658): Remove this field. - useNeighborCache bool - // nudDisp is the NUD event dispatcher that is used to send the netstack // integrator NUD related events. nudDisp NUDDispatcher @@ -516,17 +510,6 @@ type Options struct { // NUDConfigs is the default NUD configurations used by interfaces. NUDConfigs NUDConfigurations - // UseNeighborCache is unused. - // - // TODO(gvisor.dev/issue/4658): Remove this field. - UseNeighborCache bool - - // UseLinkAddrCache indicates that the legacy link address cache should be - // used for link resolution. - // - // TODO(gvisor.dev/issue/4658): Remove this field. - UseLinkAddrCache bool - // NUDDisp is the NUD event dispatcher that an integrator can provide to // receive NUD related events. NUDDisp NUDDispatcher @@ -656,7 +639,7 @@ func New(opts Options) *Stack { s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - nics: make(map[tcpip.NICID]*NIC), + nics: make(map[tcpip.NICID]*nic), cleanupEndpoints: make(map[TransportEndpoint]struct{}), PortManager: ports.NewPortManager(), clock: clock, @@ -666,7 +649,6 @@ func New(opts Options) *Stack { icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), nudConfigs: opts.NUDConfigs, - useNeighborCache: !opts.UseLinkAddrCache, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), @@ -1233,7 +1215,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return nic.primaryAddress(protocol), true } -func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { +func (s *Stack) getAddressEP(nic *nic, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { if len(localAddr) == 0 { return nic.primaryEndpoint(netProto, remoteAddr) } @@ -1244,13 +1226,13 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP // from the specified NIC. // // Precondition: s.mu must be read locked. -func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *nic, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) if localAddressEndpoint == nil { return nil } - var outgoingNIC *NIC + var outgoingNIC *nic // Prefer a local route to the same interface as the local address. if localAddressNIC.hasAddress(netProto, remoteAddr) { outgoingNIC = localAddressNIC @@ -1319,6 +1301,11 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, return nil } +// HandleLocal returns true if non-loopback interfaces are allowed to loop packets. +func (s *Stack) HandleLocal() bool { + return s.handleLocal +} + // FindRoute creates a route to the given destination address, leaving through // the given NIC and local address (if provided). // @@ -1479,6 +1466,17 @@ func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool return ok } +// CheckDuplicateAddress performs duplicate address detection for the address on +// the specified interface. +func (s *Stack) CheckDuplicateAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, h DADCompletionHandler) (DADCheckAddressDisposition, tcpip.Error) { + nic, ok := s.nics[nicID] + if !ok { + return 0, &tcpip.ErrUnknownNICID{} + } + + return nic.checkDuplicateAddress(protocol, addr, h) +} + // CheckLocalAddress determines if the given local address exists, and if it // does, returns the id of the NIC it's bound to. Returns 0 if the address // does not exist. @@ -1493,20 +1491,16 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto return 0 } - addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) - if addressEndpoint == nil { - return 0 + if nic.CheckLocalAddress(protocol, addr) { + return nic.id } - addressEndpoint.DecRef() - - return nic.id + return 0 } // Go through all the NICs. for _, nic := range s.nics { - if addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint); addressEndpoint != nil { - addressEndpoint.DecRef() + if nic.CheckLocalAddress(protocol, addr) { return nic.id } } @@ -2062,22 +2056,6 @@ func generateRandInt64() int64 { return v } -// FindNetworkEndpoint returns the network endpoint for the given address. -func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, tcpip.Error) { - s.mu.RLock() - defer s.mu.RUnlock() - - for _, nic := range s.nics { - addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint) - if addressEndpoint == nil { - continue - } - addressEndpoint.DecRef() - return nic.getNetworkEndpoint(netProto), nil - } - return nil, &tcpip.ErrBadAddress{} -} - // FindNICNameFromID returns the name of the NIC for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { s.mu.RLock() @@ -2103,13 +2081,6 @@ const ( // ParsedOK indicates that a packet was successfully parsed. ParsedOK ParseResult = iota - // UnknownNetworkProtocol indicates that the network protocol is unknown. - UnknownNetworkProtocol - - // NetworkLayerParseError indicates that the network packet was not - // successfully parsed. - NetworkLayerParseError - // UnknownTransportProtocol indicates that the transport protocol is unknown. UnknownTransportProtocol @@ -2118,31 +2089,19 @@ const ( TransportLayerParseError ) -// ParsePacketBuffer parses the provided packet buffer. -func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult { - netProto, ok := s.networkProtocols[protocol] - if !ok { - return UnknownNetworkProtocol - } - - transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) - if !ok { - return NetworkLayerParseError - } - if !hasTransportHdr { - return ParsedOK - } - +// ParsePacketBufferTransport parses the provided packet buffer's transport +// header. +func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. - if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber { + if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { return ParsedOK } - pkt.TransportProtocolNumber = transProtoNum + pkt.TransportProtocolNumber = protocol // Parse the transport header if present. - state, ok := s.transportProtocols[transProtoNum] + state, ok := s.transportProtocols[protocol] if !ok { return UnknownTransportProtocol } @@ -2164,7 +2123,7 @@ func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { return protos } -func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { +func isSubnetBroadcastOnNIC(nic *nic, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint) if addressEndpoint == nil { return false diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b641a4aaa..92a0cb401 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -119,6 +119,10 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { } func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { + if _, _, ok := f.proto.Parse(pkt); !ok { + return + } + // Increment the received packet count in the protocol descriptor. netHdr := pkt.NetworkHeader().View() @@ -2569,16 +2573,16 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } - ndpConfigs := ipv6.DefaultNDPConfigurations() + dadConfigs := stack.DefaultDADConfigurations() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, AutoGenLinkLocal: true, NDPDisp: &ndpDisp, + DADConfigs: dadConfigs, })}, } - e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1) + e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2594,7 +2598,7 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // Wait for DAD to resolve. select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): // We should get a resolution event after 1s (default time to // resolve as per default NDP configurations). Waiting for that // resolution time + an extra 1s without a resolution event @@ -3231,7 +3235,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ + DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, }, @@ -4320,7 +4324,6 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - UseNeighborCache: true, Clock: clock, }) e := channel.New(0, 0, "") diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 7d8d0851e..292e51d20 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -183,7 +183,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet } // handleError delivers an error to the transport endpoint identified by id. -func (epsByNIC *endpointsByNIC) handleError(n *NIC, id TransportEndpointID, transErr TransportError, pkt *PacketBuffer) { +func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, transErr TransportError, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -599,7 +599,7 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb // endpoint. // // Returns true if the error was delivered. -func (d *transportDemuxer) deliverError(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverError(n *nic, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false |