summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/forwarder_test.go655
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go68
-rw-r--r--pkg/tcpip/stack/ndp_test.go792
-rw-r--r--pkg/tcpip/stack/nic.go94
-rw-r--r--pkg/tcpip/stack/nic_test.go2
-rw-r--r--pkg/tcpip/stack/nud_test.go16
-rw-r--r--pkg/tcpip/stack/registration.go4
-rw-r--r--pkg/tcpip/stack/route.go19
-rw-r--r--pkg/tcpip/stack/stack.go100
-rw-r--r--pkg/tcpip/stack/stack_test.go2
10 files changed, 1126 insertions, 626 deletions
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 5a684eb9d..91165ebc7 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -51,6 +51,8 @@ type fwdTestNetworkEndpoint struct {
ep LinkEndpoint
}
+var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
+
func (f *fwdTestNetworkEndpoint) MTU() uint32 {
return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
@@ -110,11 +112,13 @@ func (*fwdTestNetworkEndpoint) Close() {}
// resolution.
type fwdTestNetworkProtocol struct {
addrCache *linkAddrCache
+ neigh *neighborCache
addrResolveDelay time.Duration
- onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress)
+ onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
}
+var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -141,7 +145,7 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true
}
-func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint {
+func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint {
return &fwdTestNetworkEndpoint{
nicID: nicID,
proto: f,
@@ -163,9 +167,9 @@ func (f *fwdTestNetworkProtocol) Close() {}
func (f *fwdTestNetworkProtocol) Wait() {}
func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
- if f.addrCache != nil && f.onLinkAddressResolved != nil {
+ if f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
- f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr)
+ f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr)
})
}
return nil
@@ -300,13 +304,16 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco
panic("not implemented")
}
-func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
+func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
NetworkProtocols: []NetworkProtocol{proto},
+ UseNeighborCache: useNeighborCache,
})
- proto.addrCache = s.linkAddrCache
+ if !useNeighborCache {
+ proto.addrCache = s.linkAddrCache
+ }
// Enable forwarding.
s.SetForwarding(true)
@@ -337,6 +344,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
t.Fatal("AddAddress #2 failed:", err)
}
+ if useNeighborCache {
+ // Control the neighbor cache for NIC 2.
+ nic, ok := s.nics[2]
+ if !ok {
+ t.Fatal("failed to get the neighbor cache for NIC 2")
+ }
+ proto.neigh = nic.neigh
+ }
+
// Route all packets to NIC 2.
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -350,79 +366,129 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
}
func TestForwardingWithStaticResolver(t *testing.T) {
- // Create a network protocol with a static resolver.
- proto := &fwdTestNetworkProtocol{
- onResolveStaticAddress:
- // The network address 3 is resolved to the link address "c".
- func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
- if addr == "\x03" {
- return "c", true
- }
- return "", false
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ // Create a network protocol with a static resolver.
+ proto := &fwdTestNetworkProtocol{
+ onResolveStaticAddress:
+ // The network address 3 is resolved to the link address "c".
+ func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == "\x03" {
+ return "c", true
+ }
+ return "", false
+ },
+ }
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
+ ep1, ep2 := fwdTestNetFactory(t, proto, test.useNeighborCache)
- var p fwdTestPacketInfo
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
- select {
- case p = <-ep2.C:
- default:
- t.Fatal("packet not forwarded")
- }
+ var p fwdTestPacketInfo
- // Test that the static address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ select {
+ case p = <-ep2.C:
+ default:
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the static address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ })
}
}
func TestForwardingWithFakeResolver(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
- // Any address will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any address will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Any address will be resolved to the link address "c".
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
- var p fwdTestPacketInfo
+ var p fwdTestPacketInfo
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ })
}
}
@@ -430,7 +496,9 @@ func TestForwardingWithNoResolver(t *testing.T) {
// Create a network protocol without a resolver.
proto := &fwdTestNetworkProtocol{}
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ // Whether or not we use the neighbor cache here does not matter since
+ // neither linkAddrCache nor neighborCache will be used.
+ ep1, ep2 := fwdTestNetFactory(t, proto, false /* useNeighborCache */)
// inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
@@ -448,203 +516,334 @@ func TestForwardingWithNoResolver(t *testing.T) {
}
func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
- // Only packets to address 3 will be resolved to the
- // link address "c".
- if addr == "\x03" {
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
- }
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Only packets to address 3 will be resolved to the
+ // link address "c".
+ if addr == "\x03" {
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ }
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Only packets to address 3 will be resolved to the
+ // link address "c".
+ if addr == "\x03" {
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ }
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
-
- // Inject an inbound packet to address 4 on NIC 1. This packet should
- // not be forwarded.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 4
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf = buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
+
+ // Inject an inbound packet to address 4 on NIC 1. This packet should
+ // not be forwarded.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 4
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf = buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
- if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
- }
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ })
}
}
func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
- // Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Any packets will be resolved to the link address "c".
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
-
- // Inject two inbound packets to address 3 on NIC 1.
- for i := 0; i < 2; i++ {
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
-
- for i := 0; i < 2; i++ {
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
- if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
- }
+ // Inject two inbound packets to address 3 on NIC 1.
+ for i := 0; i < 2; i++ {
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
+ for i := 0; i < 2; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+ })
}
}
func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
- // Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Any packets will be resolved to the link address "c".
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
-
- for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
- // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- // Set the packet sequence number.
- binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
-
- for i := 0; i < maxPendingPacketsPerResolution; i++ {
- var p fwdTestPacketInfo
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
-
- b := PayloadSince(p.Pkt.NetworkHeader())
- if b[dstAddrOffset] != 3 {
- t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
- }
- if len(b) < fwdTestNetHeaderLen+2 {
- t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b)
- }
- seqNumBuf := b[fwdTestNetHeaderLen:]
-
- // The first 5 packets should not be forwarded so the sequence number should
- // start with 5.
- want := uint16(i + 5)
- if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
- t.Fatalf("got the packet #%d, want = #%d", n, want)
- }
+ for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
+ // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ // Set the packet sequence number.
+ binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
+ for i := 0; i < maxPendingPacketsPerResolution; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ b := PayloadSince(p.Pkt.NetworkHeader())
+ if b[dstAddrOffset] != 3 {
+ t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
+ }
+ if len(b) < fwdTestNetHeaderLen+2 {
+ t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b)
+ }
+ seqNumBuf := b[fwdTestNetHeaderLen:]
+
+ // The first 5 packets should not be forwarded so the sequence number should
+ // start with 5.
+ want := uint16(i + 5)
+ if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
+ t.Fatalf("got the packet #%d, want = #%d", n, want)
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+ })
}
}
func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
- // Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Any packets will be resolved to the link address "c".
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
-
- for i := 0; i < maxPendingResolutions+5; i++ {
- // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
- // Each packet has a different destination address (3 to
- // maxPendingResolutions + 7).
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = byte(3 + i)
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
-
- for i := 0; i < maxPendingResolutions; i++ {
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
- // The first 5 packets (address 3 to 7) should not be forwarded
- // because their address resolutions are interrupted.
- if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset])
- }
+ for i := 0; i < maxPendingResolutions+5; i++ {
+ // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
+ // Each packet has a different destination address (3 to
+ // maxPendingResolutions + 7).
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = byte(3 + i)
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
+ for i := 0; i < maxPendingResolutions; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ // The first 5 packets (address 3 to 7) should not be forwarded
+ // because their address resolutions are interrupted.
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+ })
}
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index b15b8d1cb..14fb4239b 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -275,3 +275,71 @@ func TestStaticResolution(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
}
}
+
+// TestCacheWaker verifies that RemoveWaker removes a waker previously added
+// through get().
+func TestCacheWaker(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+
+ // First, sanity check that wakers are working.
+ {
+ linkRes := &testLinkAddressResolver{cache: c}
+ s := sleep.Sleeper{}
+ defer s.Done()
+
+ const wakerID = 1
+ w := sleep.Waker{}
+ s.AddWaker(&w, wakerID)
+
+ e := testAddrs[0]
+
+ if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
+ }
+ id, ok := s.Fetch(true /* block */)
+ if !ok {
+ t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)")
+ }
+ if id != wakerID {
+ t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID)
+ }
+
+ if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil {
+ t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
+ } else if got != e.linkAddr {
+ t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
+ }
+ }
+
+ // Check that RemoveWaker works.
+ {
+ linkRes := &testLinkAddressResolver{cache: c}
+ s := sleep.Sleeper{}
+ defer s.Done()
+
+ const wakerID = 2 // different than the ID used in the sanity check
+ w := sleep.Waker{}
+ s.AddWaker(&w, wakerID)
+
+ e := testAddrs[1]
+ linkRes.onLinkAddressRequest = func() {
+ // Remove the waker before the linkAddrCache has the opportunity to send
+ // a notification.
+ c.removeWaker(e.addr, &w)
+ }
+
+ if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
+ }
+
+ if got, err := getBlocking(c, e.addr, linkRes); err != nil {
+ t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
+ } else if got != e.linkAddr {
+ t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
+ }
+
+ if id, ok := s.Fetch(false /* block */); ok {
+ t.Fatalf("unexpected notification from waker with id %d", id)
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 21bf53010..67dc5364f 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -2787,7 +2787,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
// stack.Stack will have a default route through the router (llAddr3) installed
// and a static link-address (linkAddr3) added to the link address cache for the
// router.
-func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
+func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useNeighborCache bool) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
t.Helper()
ndpDisp := &ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
@@ -2800,7 +2800,8 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
HandleRAs: true,
AutoGenGlobalAddresses: true,
},
- NDPDisp: ndpDisp,
+ NDPDisp: ndpDisp,
+ UseNeighborCache: useNeighborCache,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -2810,7 +2811,11 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
Gateway: llAddr3,
NIC: nicID,
}})
- s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+ if useNeighborCache {
+ s.AddStaticNeighbor(nicID, llAddr3, linkAddr3)
+ } else {
+ s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+ }
return ndpDisp, e, s
}
@@ -2884,110 +2889,128 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA
// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when
// receiving a PI with 0 preferred lifetime.
func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
- const nicID = 1
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ const nicID = 1
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
}
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
- }
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
- // Receive PI for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- expectPrimaryAddr(addr1)
+ // Receive PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ expectPrimaryAddr(addr1)
- // Deprecate addr for prefix1 immedaitely.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- // addr should still be the primary endpoint as there are no other addresses.
- expectPrimaryAddr(addr1)
+ // Deprecate addr for prefix1 immedaitely.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ // addr should still be the primary endpoint as there are no other addresses.
+ expectPrimaryAddr(addr1)
- // Refresh lifetimes of addr generated from prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
+ // Refresh lifetimes of addr generated from prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
- // Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
- // Deprecate addr for prefix2 immedaitely.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- expectAutoGenAddrEvent(addr2, deprecatedAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr1 should be the primary endpoint now since addr2 is deprecated but
- // addr1 is not.
- expectPrimaryAddr(addr1)
- // addr2 is deprecated but if explicitly requested, it should be used.
- fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
- if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
- }
+ // Deprecate addr for prefix2 immedaitely.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr1 should be the primary endpoint now since addr2 is deprecated but
+ // addr1 is not.
+ expectPrimaryAddr(addr1)
+ // addr2 is deprecated but if explicitly requested, it should be used.
+ fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+ }
- // Another PI w/ 0 preferred lifetime should not result in a deprecation
- // event.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
- }
+ // Another PI w/ 0 preferred lifetime should not result in a deprecation
+ // event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+ }
- // Refresh lifetimes of addr generated from prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
+ // Refresh lifetimes of addr generated from prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+ })
}
- expectPrimaryAddr(addr2)
}
// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated
@@ -2996,217 +3019,236 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
const nicID = 1
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
}
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
- expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
- t.Helper()
+ expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
}
- case <-time.After(timeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- }
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
- }
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
- // Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
- // Receive a PI for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr1)
+ // Receive a PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr1)
- // Refresh lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
+ // Refresh lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
- // Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr2 should be the primary endpoint now since addr1 is deprecated but
- // addr2 is not.
- expectPrimaryAddr(addr2)
- // addr1 is deprecated but if explicitly requested, it should be used.
- fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since addr1 is deprecated but
+ // addr2 is not.
+ expectPrimaryAddr(addr2)
+ // addr1 is deprecated but if explicitly requested, it should be used.
+ fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
- // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
- // sure we do not get a deprecation event again.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr2)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
+ // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
+ // sure we do not get a deprecation event again.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
- // Refresh lifetimes for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- // addr1 is the primary endpoint again since it is non-deprecated now.
- expectPrimaryAddr(addr1)
+ // Refresh lifetimes for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ // addr1 is the primary endpoint again since it is non-deprecated now.
+ expectPrimaryAddr(addr1)
- // Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr2 should be the primary endpoint now since it is not deprecated.
- expectPrimaryAddr(addr2)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since it is not deprecated.
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
- // Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
+ // Wait for addr of prefix1 to be invalidated.
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
- // Refresh both lifetimes for addr of prefix2 to the same value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
+ // Refresh both lifetimes for addr of prefix2 to the same value.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
- // Wait for a deprecation then invalidation events, or just an invalidation
- // event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation handlers could be handled in
- // either deprecation then invalidation, or invalidation then deprecation
- // (which should be cancelled by the invalidation handler).
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
- // If we get a deprecation event first, we should get an invalidation
- // event almost immediately after.
+ // Wait for a deprecation then invalidation events, or just an invalidation
+ // event. We need to cover both cases but cannot deterministically hit both
+ // cases because the deprecation and invalidation handlers could be handled in
+ // either deprecation then invalidation, or invalidation then deprecation
+ // (which should be cancelled by the invalidation handler).
select {
case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
+ // If we get a deprecation event first, we should get an invalidation
+ // event almost immediately after.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
+ // If we get an invalidation event first, we should not get a deprecation
+ // event after.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ case <-time.After(defaultAsyncNegativeEventTimeout):
+ }
+ } else {
+ t.Fatalf("got unexpected auto-generated event")
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
- // If we get an invalidation event first, we should not get a deprecation
- // event after.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- case <-time.After(defaultAsyncNegativeEventTimeout):
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should not have %s in the list of addresses", addr2)
+ }
+ // Should not have any primary endpoints.
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
}
- } else {
- t.Fatalf("got unexpected auto-generated event")
- }
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should not have %s in the list of addresses", addr2)
- }
- // Should not have any primary endpoints.
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
- }
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
- }
- defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
- if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
- t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute)
+ if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
+ t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute)
+ }
+ })
}
}
@@ -3524,110 +3566,128 @@ func TestAutoGenAddrRemoval(t *testing.T) {
func TestAutoGenAddrAfterRemoval(t *testing.T) {
const nicID = 1
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
}
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache)
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
- }
-
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
- // Receive a PI to auto-generate addr1 with a large valid and preferred
- // lifetime.
- const largeLifetimeSeconds = 999
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- expectAutoGenAddrEvent(addr1, newAddr)
- expectPrimaryAddr(addr1)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
- // Add addr2 as a static address.
- protoAddr2 := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addr2,
- }
- if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
- }
- // addr2 should be more preferred now since it is at the front of the primary
- // list.
- expectPrimaryAddr(addr2)
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
- // Get a route using addr2 to increment its reference count then remove it
- // to leave it in the permanentExpired state.
- r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
- }
- defer r.Release()
- if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
- t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
- }
- // addr1 should be preferred again since addr2 is in the expired state.
- expectPrimaryAddr(addr1)
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
- // Receive a PI to auto-generate addr2 as valid and preferred.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- expectAutoGenAddrEvent(addr2, newAddr)
- // addr2 should be more preferred now that it is closer to the front of the
- // primary list and not deprecated.
- expectPrimaryAddr(addr2)
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
- // Removing the address should result in an invalidation event immediately.
- // It should still be in the permanentExpired state because r is still held.
- //
- // We remove addr2 here to make sure addr2 was marked as a SLAAC address
- // (it was previously marked as a static address).
- if err := s.RemoveAddress(1, addr2.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
- }
- expectAutoGenAddrEvent(addr2, invalidatedAddr)
- // addr1 should be more preferred since addr2 is in the expired state.
- expectPrimaryAddr(addr1)
+ // Receive a PI to auto-generate addr1 with a large valid and preferred
+ // lifetime.
+ const largeLifetimeSeconds = 999
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectPrimaryAddr(addr1)
- // Receive a PI to auto-generate addr2 as valid and deprecated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
- expectAutoGenAddrEvent(addr2, newAddr)
- // addr1 should still be more preferred since addr2 is deprecated, even though
- // it is closer to the front of the primary list.
- expectPrimaryAddr(addr1)
+ // Add addr2 as a static address.
+ protoAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr2,
+ }
+ if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ }
+ // addr2 should be more preferred now since it is at the front of the primary
+ // list.
+ expectPrimaryAddr(addr2)
- // Receive a PI to refresh addr2's preferred lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto gen addr event")
- default:
- }
- // addr2 should be more preferred now that it is not deprecated.
- expectPrimaryAddr(addr2)
+ // Get a route using addr2 to increment its reference count then remove it
+ // to leave it in the permanentExpired state.
+ r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
+ }
+ // addr1 should be preferred again since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and preferred.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr2 should be more preferred now that it is closer to the front of the
+ // primary list and not deprecated.
+ expectPrimaryAddr(addr2)
+
+ // Removing the address should result in an invalidation event immediately.
+ // It should still be in the permanentExpired state because r is still held.
+ //
+ // We remove addr2 here to make sure addr2 was marked as a SLAAC address
+ // (it was previously marked as a static address).
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ // addr1 should be more preferred since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr1 should still be more preferred since addr2 is deprecated, even though
+ // it is closer to the front of the primary list.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to refresh addr2's preferred lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto gen addr event")
+ default:
+ }
+ // addr2 should be more preferred now that it is not deprecated.
+ expectPrimaryAddr(addr2)
- if err := s.RemoveAddress(1, addr2.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ expectPrimaryAddr(addr1)
+ })
}
- expectAutoGenAddrEvent(addr2, invalidatedAddr)
- expectPrimaryAddr(addr1)
}
// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index aff29f9cc..0c811efdb 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -21,6 +21,7 @@ import (
"sort"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -135,18 +136,8 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
}
nic.mu.ndp.initializeTempAddrState()
- // Register supported packet endpoint protocols.
- for _, netProto := range header.Ethertypes {
- nic.mu.packetEPs[netProto] = []PacketEndpoint{}
- }
- for _, netProto := range stack.networkProtocols {
- netNum := netProto.Number()
- nic.mu.packetEPs[netNum] = nil
- nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nic, ep, stack)
- }
-
// Check for Neighbor Unreachability Detection support.
- if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 {
+ if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 && stack.useNeighborCache {
rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds()))
nic.neigh = &neighborCache{
nic: nic,
@@ -155,6 +146,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
}
}
+ // Register supported packet endpoint protocols.
+ for _, netProto := range header.Ethertypes {
+ nic.mu.packetEPs[netProto] = []PacketEndpoint{}
+ }
+ for _, netProto := range stack.networkProtocols {
+ netNum := netProto.Number()
+ nic.mu.packetEPs[netNum] = nil
+ nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nic.neigh, nic, ep, stack)
+ }
+
nic.linkEP.Attach(nic)
return nic
@@ -431,7 +432,7 @@ func (n *NIC) setSpoofing(enable bool) {
// If an IPv6 primary endpoint is requested, Source Address Selection (as
// defined by RFC 6724 section 5) will be performed.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint {
- if protocol == header.IPv6ProtocolNumber && remoteAddr != "" {
+ if protocol == header.IPv6ProtocolNumber && len(remoteAddr) != 0 {
return n.primaryIPv6Endpoint(remoteAddr)
}
@@ -818,11 +819,24 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
}
}
- ep, ok := n.networkEndpoints[protocolAddress.Protocol]
+ netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol]
if !ok {
return nil, tcpip.ErrUnknownProtocol
}
+ var nud NUDHandler
+ if n.neigh != nil {
+ // An interface value that holds a nil concrete value is itself non-nil.
+ // For this reason, n.neigh cannot be passed directly to NewEndpoint so
+ // NetworkEndpoints don't confuse it for non-nil.
+ //
+ // See https://golang.org/doc/faq#nil_error for more information.
+ nud = n.neigh
+ }
+
+ // Create the new network endpoint.
+ ep := netProto.NewEndpoint(n.id, n.stack, nud, n, n.linkEP, n.stack)
+
isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address)
// If the address is an IPv6 address and it is a permanent address,
@@ -844,10 +858,11 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
deprecated: deprecated,
}
- // Set up cache if link address resolution exists for this protocol.
+ // Set up resolver if link address resolution exists for this protocol.
if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
- if _, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok {
+ if linkRes, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok {
ref.linkCache = n.stack
+ ref.linkRes = linkRes
}
}
@@ -1082,6 +1097,51 @@ func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
return n.removePermanentAddressLocked(addr)
}
+func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) {
+ if n.neigh == nil {
+ return nil, tcpip.ErrNotSupported
+ }
+
+ return n.neigh.entries(), nil
+}
+
+func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) {
+ if n.neigh == nil {
+ return
+ }
+
+ n.neigh.removeWaker(addr, w)
+}
+
+func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error {
+ if n.neigh == nil {
+ return tcpip.ErrNotSupported
+ }
+
+ n.neigh.addStaticEntry(addr, linkAddress)
+ return nil
+}
+
+func (n *NIC) removeNeighbor(addr tcpip.Address) *tcpip.Error {
+ if n.neigh == nil {
+ return tcpip.ErrNotSupported
+ }
+
+ if !n.neigh.removeEntry(addr) {
+ return tcpip.ErrBadAddress
+ }
+ return nil
+}
+
+func (n *NIC) clearNeighbors() *tcpip.Error {
+ if n.neigh == nil {
+ return tcpip.ErrNotSupported
+ }
+
+ n.neigh.clear()
+ return nil
+}
+
// joinGroup adds a new endpoint for the given multicast address, if none
// exists yet. Otherwise it just increments its count.
func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
@@ -1662,6 +1722,10 @@ type referencedNetworkEndpoint struct {
// protocol. Set to nil otherwise.
linkCache LinkAddressCache
+ // linkRes is set if link address resolution is enabled for this protocol.
+ // Set to nil otherwise.
+ linkRes LinkAddressResolver
+
// refs is counting references held for this endpoint. When refs hits zero it
// triggers the automatic removal of the endpoint from the NIC.
refs int32
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index d312a79eb..1e065b5c1 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -192,7 +192,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address)
}
// NewEndpoint implements NetworkProtocol.NewEndpoint.
-func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint {
+func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint {
return &testIPv6Endpoint{
nicID: nicID,
linkEP: linkEP,
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index 2494ee610..2b97e5972 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -61,6 +61,7 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) {
// stack will only allocate neighbor caches if a protocol providing link
// address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ UseNeighborCache: true,
})
// No NIC with ID 1 yet.
@@ -84,7 +85,8 @@ func TestNUDConfigurationFailsForNotSupported(t *testing.T) {
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(stack.Options{
- NUDConfigs: stack.DefaultNUDConfigurations(),
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -108,7 +110,8 @@ func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) {
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(stack.Options{
- NUDConfigs: stack.DefaultNUDConfigurations(),
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -136,6 +139,7 @@ func TestDefaultNUDConfigurations(t *testing.T) {
// address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: stack.DefaultNUDConfigurations(),
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -190,6 +194,7 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -246,6 +251,7 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -325,6 +331,7 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -386,6 +393,7 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -437,6 +445,7 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -488,6 +497,7 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -539,6 +549,7 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -590,6 +601,7 @@ func TestNUDConfigurationsUnreachableTime(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index aca2f77f8..21ac38583 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -298,7 +298,7 @@ type NetworkProtocol interface {
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint
+ NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -488,7 +488,7 @@ type LinkAddressResolver interface {
ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool)
// LinkAddressProtocol returns the network protocol of the
- // addresses this this resolver can resolve.
+ // addresses this resolver can resolve.
LinkAddressProtocol() tcpip.NetworkProtocolNumber
}
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index e267bebb0..c2eabde9e 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -141,6 +141,16 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
}
nextAddr = r.RemoteAddress
}
+
+ if r.ref.nic.neigh != nil {
+ entry, ch, err := r.ref.nic.neigh.entry(nextAddr, r.LocalAddress, r.ref.linkRes, waker)
+ if err != nil {
+ return ch, err
+ }
+ r.RemoteLinkAddress = entry.LinkAddr
+ return nil, nil
+ }
+
linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
if err != nil {
return ch, err
@@ -155,6 +165,12 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
if nextAddr == "" {
nextAddr = r.RemoteAddress
}
+
+ if r.ref.nic.neigh != nil {
+ r.ref.nic.neigh.removeWaker(nextAddr, waker)
+ return
+ }
+
r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker)
}
@@ -163,6 +179,9 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
//
// The NIC r uses must not be locked.
func (r *Route) IsResolutionRequired() bool {
+ if r.ref.nic.neigh != nil {
+ return r.ref.isValidForOutgoing() && r.ref.linkRes != nil && r.RemoteLinkAddress == ""
+ }
return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index a3f87c8af..7f5ed9e83 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -248,7 +248,7 @@ type RcvBufAutoTuneParams struct {
// was started.
MeasureTime time.Time
- // CopiedBytes is the number of bytes copied to userspace since
+ // CopiedBytes is the number of bytes copied to user space since
// this measure began.
CopiedBytes int
@@ -461,6 +461,10 @@ type Stack struct {
// nudConfigs is the default NUD configurations used by interfaces.
nudConfigs NUDConfigurations
+ // useNeighborCache indicates whether ARP and NDP packets should be handled
+ // by the NIC's neighborCache instead of linkAddrCache.
+ useNeighborCache bool
+
// autoGenIPv6LinkLocal determines whether or not the stack will attempt
// to auto-generate an IPv6 link-local address for newly enabled non-loopback
// NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
@@ -541,6 +545,13 @@ type Options struct {
// NUDConfigs is the default NUD configurations used by interfaces.
NUDConfigs NUDConfigurations
+ // UseNeighborCache indicates whether ARP and NDP packets should be handled
+ // by the Neighbor Unreachability Detection (NUD) state machine. This flag
+ // also enables the APIs for inspecting and modifying the neighbor table via
+ // NUDDispatcher and the following Stack methods: Neighbors, RemoveNeighbor,
+ // and ClearNeighbors.
+ UseNeighborCache bool
+
// AutoGenIPv6LinkLocal determines whether or not the stack will attempt to
// auto-generate an IPv6 link-local address for newly enabled non-loopback
// NICs.
@@ -715,6 +726,7 @@ func New(opts Options) *Stack {
seed: generateRandUint32(),
ndpConfigs: opts.NDPConfigs,
nudConfigs: opts.NUDConfigs,
+ useNeighborCache: opts.UseNeighborCache,
autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
@@ -1209,8 +1221,8 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
+ nic, ok := s.nics[id]
+ if !ok {
return tcpip.ErrUnknownNICID
}
@@ -1335,8 +1347,8 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto
// If a NIC is specified, we try to find the address there only.
if nicID != 0 {
- nic := s.nics[nicID]
- if nic == nil {
+ nic, ok := s.nics[nicID]
+ if !ok {
return 0
}
@@ -1367,8 +1379,8 @@ func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[nicID]
- if nic == nil {
+ nic, ok := s.nics[nicID]
+ if !ok {
return tcpip.ErrUnknownNICID
}
@@ -1383,8 +1395,8 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[nicID]
- if nic == nil {
+ nic, ok := s.nics[nicID]
+ if !ok {
return tcpip.ErrUnknownNICID
}
@@ -1416,8 +1428,33 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
}
-// RemoveWaker implements LinkAddressCache.RemoveWaker.
+// Neighbors returns all IP to MAC address associations.
+func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if !ok {
+ return nil, tcpip.ErrUnknownNICID
+ }
+
+ return nic.neighbors()
+}
+
+// RemoveWaker removes a waker that has been added when link resolution for
+// addr was requested.
func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
+ if s.useNeighborCache {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if ok {
+ nic.removeWaker(addr, waker)
+ }
+ return
+ }
+
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1427,6 +1464,47 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.
}
}
+// AddStaticNeighbor statically associates an IP address to a MAC address.
+func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.addStaticNeighbor(addr, linkAddr)
+}
+
+// RemoveNeighbor removes an IP to MAC address association previously created
+// either automically or by AddStaticNeighbor. Returns ErrBadAddress if there
+// is no association with the provided address.
+func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.removeNeighbor(addr)
+}
+
+// ClearNeighbors removes all IP to MAC address associations.
+func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.clearNeighbors()
+}
+
// RegisterTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
@@ -1961,7 +2039,7 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
return nil, tcpip.ErrBadAddress
}
-// FindNICNameFromID returns the name of the nic for the given NICID.
+// FindNICNameFromID returns the name of the NIC for the given NICID.
func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
s.mu.RLock()
defer s.mu.RUnlock()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 106645c50..1deeccb89 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -197,7 +197,7 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint {
+func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint {
return &fakeNetworkEndpoint{
nicID: nicID,
proto: f,