diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 197 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 91 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 109 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache.go | 76 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache_test.go | 439 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry.go | 61 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_test.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 223 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic_test.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/stack/nud.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/stack/nud_test.go | 229 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer_unsafe.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 104 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 190 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 70 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 17 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 2 |
21 files changed, 1019 insertions, 908 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index bb30556cf..ee23c9b98 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -72,6 +72,7 @@ go_library( "nud.go", "packet_buffer.go", "packet_buffer_list.go", + "packet_buffer_unsafe.go", "pending_packets.go", "rand.go", "registration.go", diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 63a42a2ea..c24f56ece 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -41,6 +41,7 @@ const ( protocolNumberOffset = 2 ) +var _ LinkAddressResolver = (*fwdTestNetworkEndpoint)(nil) var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) // fwdTestNetworkEndpoint is a network-layer protocol endpoint. @@ -153,7 +154,6 @@ type fwdTestNetworkEndpointStats struct{} // IsNetworkEndpointStats implements stack.NetworkEndpointStats. func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {} -var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) // fwdTestNetworkProtocol is a network-layer protocol that implements Address @@ -161,10 +161,9 @@ var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) type fwdTestNetworkProtocol struct { stack *Stack - addrCache *linkAddrCache - neigh *neighborCache + neighborTable neighborTable addrResolveDelay time.Duration - onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) + onLinkAddressResolved func(neighborTable, tcpip.Address, tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) mu struct { @@ -197,7 +196,7 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } -func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { +func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, dispatcher TransportDispatcher) NetworkEndpoint { e := &fwdTestNetworkEndpoint{ nic: nic, proto: f, @@ -219,23 +218,23 @@ func (*fwdTestNetworkProtocol) Close() {} func (*fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { - if f.onLinkAddressResolved != nil { - time.AfterFunc(f.addrResolveDelay, func() { - f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) +func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { + if fn := f.proto.onLinkAddressResolved; fn != nil { + time.AfterFunc(f.proto.addrResolveDelay, func() { + fn(f.proto.neighborTable, addr, remoteLinkAddr) }) } return nil } -func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if f.onResolveStaticAddress != nil { - return f.onResolveStaticAddress(addr) +func (f *fwdTestNetworkEndpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if fn := f.proto.onResolveStaticAddress; fn != nil { + return fn(addr) } return "", false } -func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } @@ -401,11 +400,9 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC if !ok { t.Fatal("NIC 2 does not exist") } - if useNeighborCache { - // Control the neighbor cache for NIC 2. - proto.neigh = nic.neigh - } else { - proto.addrCache = nic.linkAddrCache + + if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok { + proto.neighborTable = l.neighborTable } // Route all packets to NIC 2. @@ -482,43 +479,35 @@ func TestForwardingWithFakeResolver(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any address will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any address will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -573,7 +562,7 @@ func TestForwardingWithNoResolver(t *testing.T) { func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { proto := &fwdTestNetworkProtocol{ addrResolveDelay: 50 * time.Millisecond, - onLinkAddressResolved: func(*linkAddrCache, *neighborCache, tcpip.Address, tcpip.LinkAddress) { + onLinkAddressResolved: func(neighborTable, tcpip.Address, tcpip.LinkAddress) { // Don't resolve the link address. }, } @@ -606,49 +595,38 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - cache.AddLinkAddress(addr, "c") - } - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Only packets to address 3 will be resolved to the // link address "c". if addr == "\x03" { - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) } }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. @@ -693,43 +671,35 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) // Inject two inbound packets to address 3 on NIC 1. for i := 0; i < 2; i++ { @@ -769,43 +739,35 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. @@ -864,38 +826,31 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) for i := 0; i < maxPendingResolutions+5; i++ { // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 930b8f795..5b6b58b1d 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -24,8 +24,6 @@ import ( const linkAddrCacheSize = 512 // max cache entries -var _ LinkAddressCache = (*linkAddrCache)(nil) - // linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses. // // The entries are stored in a ring buffer, oldest entry replaced first. @@ -34,6 +32,8 @@ var _ LinkAddressCache = (*linkAddrCache)(nil) type linkAddrCache struct { nic *NIC + linkRes LinkAddressResolver + // ageLimit is how long a cache entry is valid for. ageLimit time.Duration @@ -140,7 +140,7 @@ func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { } // add adds a k -> v mapping to the cache. -func (c *linkAddrCache) AddLinkAddress(k tcpip.Address, v tcpip.LinkAddress) { +func (c *linkAddrCache) add(k tcpip.Address, v tcpip.LinkAddress) { // Calculate expiration time before acquiring the lock, since expiration is // relative to the time when information was learned, rather than when it // happened to be inserted into the cache. @@ -198,10 +198,10 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { return entry } -// get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { +// get reports any known link address for addr. +func (c *linkAddrCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { c.mu.Lock() - entry := c.getOrCreateEntryLocked(k) + entry := c.getOrCreateEntryLocked(addr) entry.mu.Lock() defer entry.mu.Unlock() c.mu.Unlock() @@ -224,7 +224,7 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA } if entry.mu.done == nil { entry.mu.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, nic, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + go c.startAddressResolution(addr, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{} default: @@ -232,11 +232,11 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA } } -func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.Address, localAddr tcpip.Address, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */, nic) + c.linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */) select { case now := <-time.After(c.resolutionTimeout): @@ -280,13 +280,80 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt return true } -func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { - c := &linkAddrCache{ +func (c *linkAddrCache) init(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int, linkRes LinkAddressResolver) { + *c = linkAddrCache{ nic: nic, + linkRes: linkRes, ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, } + + c.mu.Lock() c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) - return c + c.mu.Unlock() +} + +var _ neighborTable = (*linkAddrCache)(nil) + +func (*linkAddrCache) neighbors() ([]NeighborEntry, tcpip.Error) { + return nil, &tcpip.ErrNotSupported{} +} + +func (c *linkAddrCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) { + c.add(addr, linkAddr) +} + +func (*linkAddrCache) remove(addr tcpip.Address) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +func (*linkAddrCache) removeAll() tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +func (c *linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress) { + if len(linkAddr) != 0 { + // NUD allows probes without a link address but linkAddrCache + // is a simple neighbor table which does not implement NUD. + // + // As per RFC 4861 section 4.3, + // + // Source link-layer address + // The link-layer address for the sender. MUST NOT be + // included when the source IP address is the + // unspecified address. Otherwise, on link layers + // that have addresses this option MUST be included in + // multicast solicitations and SHOULD be included in + // unicast solicitations. + c.add(addr, linkAddr) + } +} + +func (c *linkAddrCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + if len(linkAddr) != 0 { + // NUD allows confirmations without a link address but linkAddrCache + // is a simple neighbor table which does not implement NUD. + // + // As per RFC 4861 section 4.4, + // + // Target link-layer address + // The link-layer address for the target, i.e., the + // sender of the advertisement. This option MUST be + // included on link layers that have addresses when + // responding to multicast solicitations. When + // responding to a unicast Neighbor Solicitation this + // option SHOULD be included. + c.add(addr, linkAddr) + } +} + +func (c *linkAddrCache) handleUpperLevelConfirmation(tcpip.Address) {} + +func (*linkAddrCache) nudConfig() (NUDConfigurations, tcpip.Error) { + return NUDConfigurations{}, &tcpip.ErrNotSupported{} +} + +func (*linkAddrCache) setNUDConfig(NUDConfigurations) tcpip.Error { + return &tcpip.ErrNotSupported{} } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 466a5e8d9..9e7f331c9 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -48,7 +48,7 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { +func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { // TODO(gvisor.dev/issue/5141): Use a fake clock. time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { @@ -60,7 +60,7 @@ func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { for _, ta := range testAddrs { if ta.addr == addr { - r.cache.AddLinkAddress(ta.addr, ta.linkAddr) + r.cache.add(ta.addr, ta.linkAddr) break } } @@ -77,10 +77,10 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe return 1 } -func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, tcpip.Error) { +func getBlocking(c *linkAddrCache, addr tcpip.Address) (tcpip.LinkAddress, tcpip.Error) { var attemptedResolution bool for { - got, ch, err := c.get(addr, linkRes, "", nil, nil) + got, ch, err := c.get(addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); ok { if attemptedResolution { return got, &tcpip.ErrTimeout{} @@ -100,27 +100,28 @@ func newEmptyNIC() *NIC { } func TestCacheOverflow(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) + var c linkAddrCache + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) for i := len(testAddrs) - 1; i >= 0; i-- { e := testAddrs[i] - c.AddLinkAddress(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) + c.add(e.addr, e.linkAddr) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("insert %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) + t.Errorf("insert %d, c.get(%s, '', nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("insert %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) + t.Errorf("insert %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { e := testAddrs[i] - got, _, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("check %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) + t.Errorf("check %d, c.get(%s, '', nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("check %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) + t.Errorf("check %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // The earliest entries should no longer be in the cache. @@ -135,15 +136,16 @@ func TestCacheOverflow(t *testing.T) { } func TestCacheConcurrent(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, linkRes) var wg sync.WaitGroup for r := 0; r < 16; r++ { wg.Add(1) go func() { for _, e := range testAddrs { - c.AddLinkAddress(e.addr, e.linkAddr) + c.add(e.addr, e.linkAddr) } wg.Done() }() @@ -154,12 +156,12 @@ func TestCacheConcurrent(t *testing.T) { // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil, nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } e = testAddrs[0] @@ -171,38 +173,40 @@ func TestCacheConcurrent(t *testing.T) { } func TestCacheAgeLimit(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3, linkRes) e := testAddrs[0] - c.AddLinkAddress(e.addr, e.linkAddr) + c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - _, _, err := c.get(e.addr, linkRes, "", nil, nil) + _, _, err := c.get(e.addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = ErrWouldBlock", e.addr, err) + t.Errorf("got c.get(%s, '', nil) = %s, want = ErrWouldBlock", e.addr, err) } } func TestCacheReplace(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) + var c linkAddrCache + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) e := testAddrs[0] l2 := e.linkAddr + "2" - c.AddLinkAddress(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) + c.add(e.addr, e.linkAddr) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } - c.AddLinkAddress(e.addr, l2) - got, _, err = c.get(e.addr, nil, "", nil, nil) + c.add(e.addr, l2) + got, _, err = c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != l2 { - t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, l2) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, l2) } } @@ -213,34 +217,36 @@ func TestCacheResolution(t *testing.T) { // // Using a large resolution timeout decreases the probability of experiencing // this race condition and does not affect how long this test takes to run. - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1, linkRes) for i, ta := range testAddrs { - got, err := getBlocking(c, ta.addr, linkRes) + got, err := getBlocking(&c, ta.addr) if err != nil { - t.Errorf("check %d, getBlocking(_, %s, _): %s", i, ta.addr, err) + t.Errorf("check %d, getBlocking(_, %s): %s", i, ta.addr, err) } if got != ta.linkAddr { - t.Errorf("check %d, got getBlocking(_, %s, _) = %s, want = %s", i, ta.addr, got, ta.linkAddr) + t.Errorf("check %d, got getBlocking(_, %s) = %s, want = %s", i, ta.addr, got, ta.linkAddr) } } // Check that after resolved, address stays in the cache and never returns WouldBlock. for i := 0; i < 10; i++ { e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil, nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } } } func TestCacheResolutionFailed(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5, linkRes) var requestCount uint32 linkRes.onLinkAddressRequest = func() { @@ -249,20 +255,20 @@ func TestCacheResolutionFailed(t *testing.T) { // First, sanity check that resolution is working... e := testAddrs[0] - got, err := getBlocking(c, e.addr, linkRes) + got, err := getBlocking(&c, e.addr) if err != nil { - t.Errorf("getBlocking(_, %s, _): %s", e.addr, err) + t.Errorf("getBlocking(_, %s): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got getBlocking(_, %s, _) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got getBlocking(_, %s) = %s, want = %s", e.addr, got, e.linkAddr) } before := atomic.LoadUint32(&requestCount) e.addr += "2" - a, err := getBlocking(c, e.addr, linkRes) + a, err := getBlocking(&c, e.addr) if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) + t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { @@ -273,12 +279,13 @@ func TestCacheResolutionFailed(t *testing.T) { func TestCacheResolutionTimeout(t *testing.T) { resolverDelay := 500 * time.Millisecond expiration := resolverDelay / 10 - c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3) - linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c, delay: resolverDelay} + c.init(newEmptyNIC(), expiration, 1*time.Millisecond, 3, linkRes) e := testAddrs[0] - a, err := getBlocking(c, e.addr, linkRes) + a, err := getBlocking(&c, e.addr) if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) + t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 64383bc7c..0238605af 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -2796,14 +2796,8 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN NIC: nicID, }}) - if useNeighborCache { - if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) - } - } else { - if err := s.AddLinkAddress(nicID, llAddr3, linkAddr3); err != nil { - t.Fatalf("s.AddLinkAddress(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) - } + if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err) } return ndpDisp, e, s } diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 88a3ff776..7e3132058 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -42,11 +42,10 @@ type NeighborStats struct { // 2. Static entries are explicitly added by a user and have no expiration. // Their state is always Static. The amount of static entries stored in the // cache is unbounded. -// -// neighborCache implements NUDHandler. type neighborCache struct { - nic *NIC - state *NUDState + nic *NIC + state *NUDState + linkRes LinkAddressResolver // mu protects the fields below. mu sync.RWMutex @@ -62,8 +61,6 @@ type neighborCache struct { } } -var _ NUDHandler = (*neighborCache)(nil) - // getOrCreateEntry retrieves a cache entry associated with addr. The // returned entry is always refreshed in the cache (it is reachable via the // map, and its place is bumped in LRU). @@ -73,7 +70,7 @@ var _ NUDHandler = (*neighborCache)(nil) // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. -func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { +func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntry { n.mu.Lock() defer n.mu.Unlock() @@ -89,7 +86,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // The entry that needs to be created must be dynamic since all static // entries are directly added to the cache via addStaticEntry. - entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes) + entry := newNeighborEntry(n, remoteAddr, n.state) if n.dynamic.count == neighborCacheSize { e := n.dynamic.lru.Back() e.mu.Lock() @@ -126,8 +123,8 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // packet prompting NUD/link address resolution. // // TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) { - entry := n.getOrCreateEntry(remoteAddr, linkRes) +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) { + entry := n.getOrCreateEntry(remoteAddr) entry.mu.Lock() defer entry.mu.Unlock() @@ -206,7 +203,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd entry.mu.Unlock() } - n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) + n.cache[addr] = newStaticNeighborEntry(n, addr, linkAddr, n.state) } // removeEntry removes a dynamic or static entry by address from the neighbor @@ -263,27 +260,45 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { n.state.SetConfig(config) } -// HandleProbe implements NUDHandler.HandleProbe by following the logic defined -// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled -// by the caller. -func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { - entry := n.getOrCreateEntry(remoteAddr, linkRes) +var _ neighborTable = (*neighborCache)(nil) + +func (n *neighborCache) neighbors() ([]NeighborEntry, tcpip.Error) { + return n.entries(), nil +} + +func (n *neighborCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + entry, ch, err := n.entry(addr, localAddr, onResolve) + return entry.LinkAddr, ch, err +} + +func (n *neighborCache) remove(addr tcpip.Address) tcpip.Error { + if !n.removeEntry(addr) { + return &tcpip.ErrBadAddress{} + } + + return nil +} + +func (n *neighborCache) removeAll() tcpip.Error { + n.clear() + return nil +} + +// handleProbe handles a neighbor probe as defined by RFC 4861 section 7.2.3. +// +// Validation of the probe is expected to be handled by the caller. +func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + entry := n.getOrCreateEntry(remoteAddr) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() } -// HandleConfirmation implements NUDHandler.HandleConfirmation by following the -// logic defined in RFC 4861 section 7.2.5. +// handleConfirmation handles a neighbor confirmation as defined by +// RFC 4861 section 7.2.5. // -// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other -// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol -// should be deployed where preventing access to the broadcast segment might -// not be possible. SEND uses RSA key pairs to produce cryptographically -// generated addresses, as defined in RFC 3972, Cryptographically Generated -// Addresses (CGA). This ensures that the claimed source of an NDP message is -// the owner of the claimed address. -func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { +// Validation of the confirmation is expected to be handled by the caller. +func (n *neighborCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { n.mu.RLock() entry, ok := n.cache[addr] n.mu.RUnlock() @@ -309,3 +324,12 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { entry.mu.Unlock() } } + +func (n *neighborCache) nudConfig() (NUDConfigurations, tcpip.Error) { + return n.config(), nil +} + +func (n *neighborCache) setNUDConfig(c NUDConfigurations) tcpip.Error { + n.setConfig(c) + return nil +} diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 2870e4f66..b489b5e08 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -76,10 +76,15 @@ func entryDiffOptsWithSort() []cmp.Option { })) } -func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { +func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *testNeighborResolver { config.resetInvalidFields() rng := rand.New(rand.NewSource(time.Now().UnixNano())) - neigh := &neighborCache{ + linkRes := &testNeighborResolver{ + clock: clock, + entries: newTestEntryStore(), + delay: typicalLatency, + } + linkRes.neigh = &neighborCache{ nic: &NIC{ stack: &Stack{ clock: clock, @@ -88,11 +93,11 @@ func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock id: 1, stats: makeNICStats(), }, - state: NewNUDState(config, rng), - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + state: NewNUDState(config, rng), + linkRes: linkRes, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } - neigh.nic.neigh = neigh - return neigh + return linkRes } // testEntryStore contains a set of IP to NeighborEntry mappings. @@ -194,7 +199,7 @@ type testNeighborResolver struct { var _ LinkAddressResolver = (*testNeighborResolver)(nil) -func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { +func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { if !r.dropReplies { // Delay handling the request to emulate network latency. r.clock.AfterFunc(r.delay, func() { @@ -212,7 +217,7 @@ func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ // fakeRequest emulates handling a response for a link address request. func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) { if entry, ok := r.entries.entryByAddr(addr); ok { - r.neigh.HandleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ + r.neigh.handleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, @@ -242,10 +247,10 @@ func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) + linkRes := newTestNeighborResolver(&nudDisp, c, clock) - if got, want := neigh.config(), c; got != want { - t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + if got, want := linkRes.neigh.config(), c; got != want { + t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) } // No events should have been dispatched. @@ -260,14 +265,14 @@ func TestNeighborCacheSetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) + linkRes := newTestNeighborResolver(&nudDisp, c, clock) c.MinRandomFactor = 1 c.MaxRandomFactor = 1 - neigh.setConfig(c) + linkRes.neigh.setConfig(c) - if got, want := neigh.config(), c; got != want { - t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + if got, want := linkRes.neigh.config(), c; got != want { + t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) } // No events should have been dispatched. @@ -282,22 +287,15 @@ func TestNeighborCacheEntry(t *testing.T) { c := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, c, clock) - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -329,8 +327,8 @@ func TestNeighborCacheEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + if _, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err != nil { + t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -346,23 +344,16 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -394,7 +385,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } - neigh.removeEntry(entry.Addr) + linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ @@ -417,17 +408,15 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } { - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } } } type testContext struct { clock *faketime.ManualClock - neigh *neighborCache - store *testEntryStore linkRes *testNeighborResolver nudDisp *testNUDDispatcher } @@ -435,19 +424,10 @@ type testContext struct { func newTestContext(c NUDConfigurations) testContext { nudDisp := &testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nudDisp, c, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(nudDisp, c, clock) return testContext{ clock: clock, - neigh: neigh, - store: store, linkRes: linkRes, nudDisp: nudDisp, } @@ -461,17 +441,17 @@ type overflowOptions struct { func (c *testContext) overflowCache(opts overflowOptions) error { // Fill the neighbor cache to capacity to verify the LRU eviction strategy is // working properly after the entry removal. - for i := opts.startAtEntryIndex; i < c.store.size(); i++ { + for i := opts.startAtEntryIndex; i < c.linkRes.entries.size(); i++ { // Add a new entry - entry, ok := c.store.entry(i) + entry, ok := c.linkRes.entries.entry(i) if !ok { - return fmt.Errorf("c.store.entry(%d) not found", i) + return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + return fmt.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } - c.clock.Advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) var wantEvents []testEntryEventInfo @@ -479,9 +459,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { // LRU eviction strategy. Note that the number of static entries should not // affect the total number of dynamic entries that can be added. if i >= neighborCacheSize+opts.startAtEntryIndex { - removedEntry, ok := c.store.entry(i - neighborCacheSize) + removedEntry, ok := c.linkRes.entries.entry(i - neighborCacheSize) if !ok { - return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize) + return fmt.Errorf("linkRes.entries.entry(%d) not found", i-neighborCacheSize) } wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestRemoved, @@ -524,10 +504,10 @@ func (c *testContext) overflowCache(opts overflowOptions) error { // by entries() is nondeterministic, so entries have to be sorted before // comparison. wantUnsortedEntries := opts.wantStaticEntries - for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { - entry, ok := c.store.entry(i) + for i := c.linkRes.entries.size() - neighborCacheSize; i < c.linkRes.entries.size(); i++ { + entry, ok := c.linkRes.entries.entry(i) if !ok { - return fmt.Errorf("c.store.entry(%d) not found", i) + return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -537,7 +517,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, c.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, c.linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } @@ -581,15 +561,15 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } - c.clock.Advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -618,7 +598,7 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { } // Remove the entry - c.neigh.removeEntry(entry.Addr) + c.linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ @@ -657,12 +637,12 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -683,7 +663,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { } // Remove the static entry that was just added - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) // No more events should have been dispatched. c.nudDisp.mu.Lock() @@ -701,12 +681,12 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -728,7 +708,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) // Add a duplicate entry with a different link address staticLinkAddr += "duplicate" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) { wantEvents := []testEntryEventInfo{ { @@ -763,12 +743,12 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -789,7 +769,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { } // Remove the static entry that was just added - c.neigh.removeEntry(entry.Addr) + c.linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ { @@ -833,13 +813,13 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -871,7 +851,7 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { // Override the entry with a static one using the same address staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) { wantEvents := []testEntryEventInfo{ { @@ -926,14 +906,14 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) - e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + c.linkRes.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) + e, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -941,7 +921,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { State: Static, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -983,23 +963,16 @@ func TestNeighborCacheClear(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) // Add a dynamic entry. - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -1031,7 +1004,7 @@ func TestNeighborCacheClear(t *testing.T) { } // Add a static entry. - neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) + linkRes.neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) { wantEvents := []testEntryEventInfo{ @@ -1055,7 +1028,7 @@ func TestNeighborCacheClear(t *testing.T) { } // Clear should remove both dynamic and static entries. - neigh.clear() + linkRes.neigh.clear() // Remove events dispatched from clear() have no deterministic order so they // need to be sorted beforehand. @@ -1099,13 +1072,13 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -1136,7 +1109,7 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { } // Clear the cache. - c.neigh.clear() + c.linkRes.neigh.clear() { wantEvents := []testEntryEventInfo{ { @@ -1175,18 +1148,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - frequentlyUsedEntry, ok := store.entry(0) + frequentlyUsedEntry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // The following logic is very similar to overflowCache, but @@ -1194,23 +1160,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // Fill the neighbor cache to capacity for i := 0; i < neighborCacheSize; i++ { - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { @@ -1241,38 +1207,38 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Keep adding more entries - for i := neighborCacheSize; i < store.size(); i++ { + for i := neighborCacheSize; i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { - if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err) + if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", frequentlyUsedEntry.Addr, err) } } - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy - removedEntry, ok := store.entry(i - neighborCacheSize + 1) + removedEntry, ok := linkRes.entries.entry(i - neighborCacheSize + 1) if !ok { - t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1) + t.Fatalf("linkRes.entries.entry(%d) not found", i-neighborCacheSize+1) } wantEvents := []testEntryEventInfo{ { @@ -1322,10 +1288,10 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { }, } - for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ { - entry, ok := store.entry(i) + for i := linkRes.entries.size() - neighborCacheSize + 1; i < linkRes.entries.size(); i++ { + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -1335,7 +1301,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } @@ -1354,26 +1320,19 @@ func TestNeighborCacheConcurrent(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - storeEntries := store.entries() + storeEntries := linkRes.entries.entries() for _, entry := range storeEntries { var wg sync.WaitGroup for r := 0; r < concurrentProcesses; r++ { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - switch e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err.(type) { + switch e, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err.(type) { case nil, *tcpip.ErrWouldBlock: default: - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) } }(entry) } @@ -1391,10 +1350,10 @@ func TestNeighborCacheConcurrent(t *testing.T) { // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. var wantUnsortedEntries []NeighborEntry - for i := store.size() - neighborCacheSize; i < store.size(); i++ { - entry, ok := store.entry(i) + for i := linkRes.entries.size() - neighborCacheSize; i < linkRes.entries.size(); i++ { + entry, ok := linkRes.entries.entry(i) if !ok { - t.Errorf("store.entry(%d) not found", i) + t.Errorf("linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -1404,7 +1363,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } } @@ -1414,41 +1373,34 @@ func TestNeighborCacheReplace(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) // Add an entry - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } // Verify the entry exists { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) } if t.Failed() { t.FailNow() @@ -1459,21 +1411,21 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } // Notify of a link address change var updatedLinkAddr tcpip.LinkAddress { - entry, ok := store.entry(1) + entry, ok := linkRes.entries.entry(1) if !ok { - t.Fatal("store.entry(1) not found") + t.Fatal("linkRes.entries.entry(1) not found") } updatedLinkAddr = entry.LinkAddr } - store.set(0, updatedLinkAddr) - neigh.HandleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ + linkRes.entries.set(0, updatedLinkAddr) + linkRes.neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, @@ -1483,9 +1435,9 @@ func TestNeighborCacheReplace(t *testing.T) { // // Verify the entry's new link address and the new state. { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1493,17 +1445,17 @@ func TestNeighborCacheReplace(t *testing.T) { State: Delay, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } // Verify that the neighbor is now reachable. { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1511,7 +1463,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } } @@ -1521,46 +1473,39 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() + linkRes := newTestNeighborResolver(&nudDisp, config, clock) var requestCount uint32 - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - onLinkAddressRequest: func() { - atomic.AddUint32(&requestCount, 1) - }, + linkRes.onLinkAddressRequest = func() { + atomic.AddUint32(&requestCount, 1) } - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // First, sanity check that resolution is working { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } - got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + got, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1568,7 +1513,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { State: Reachable, } if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } // Verify address resolution fails for an unknown address. @@ -1576,24 +1521,24 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { entry.Addr += "2" { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } - maxAttempts := neigh.config().MaxUnicastProbes + maxAttempts := linkRes.neigh.config().MaxUnicastProbes if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want { t.Errorf("got link address request count = %d, want = %d", got, want) } @@ -1607,27 +1552,22 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { config.RetransmitTimer = time.Millisecond // small enough to cause timeout clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: time.Minute, // large enough to cause timeout - } + linkRes := newTestNeighborResolver(nil, config, clock) + // large enough to cause timeout + linkRes.delay = time.Minute - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1635,7 +1575,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } @@ -1644,31 +1584,24 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { func TestNeighborCacheRetryResolution(t *testing.T) { config := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - // Simulate a faulty link. - dropReplies: true, - } + linkRes := newTestNeighborResolver(nil, config, clock) + // Simulate a faulty link. + linkRes.dropReplies = true - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // Perform address resolution with a faulty link, which will fail. { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1676,7 +1609,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } @@ -1688,20 +1621,20 @@ func TestNeighborCacheRetryResolution(t *testing.T) { State: Failed, }, } - if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) } // Retry address resolution with a working link. linkRes.dropReplies = false { - incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + incompleteEntry, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } if incompleteEntry.State != Incomplete { t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) @@ -1713,9 +1646,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) { if !ok { t.Fatal("expected successful address resolution") } - reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + reachableEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) + t.Fatalf("linkRes.neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) } if reachableEntry.Addr != entry.Addr { t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr) @@ -1727,7 +1660,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String()) } default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } } @@ -1736,42 +1669,36 @@ func BenchmarkCacheClear(b *testing.B) { b.StopTimer() config := DefaultNUDConfigurations() clock := &tcpip.StdClock{} - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: 0, - } + linkRes := newTestNeighborResolver(nil, config, clock) + linkRes.delay = 0 // Clear for every possible size of the cache for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. for i := 0; i < cacheSize; i++ { - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - b.Fatalf("store.entry(%d) not found", i) + b.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + b.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } select { case <-ch: default: - b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } b.StartTimer() - neigh.clear() + linkRes.neigh.clear() b.StopTimer() } } diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 53ac9bb6e..b05f96d4f 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -77,11 +77,7 @@ const ( type neighborEntry struct { neighborEntryEntry - nic *NIC - - // linkRes provides the functionality to send reachability probes, used in - // Neighbor Unreachability Detection. - linkRes LinkAddressResolver + cache *neighborCache // nudState points to the Neighbor Unreachability Detection configuration. nudState *NUDState @@ -106,10 +102,9 @@ type neighborEntry struct { // state, Unknown. Transition out of Unknown by calling either // `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created // neighborEntry. -func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { +func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState *NUDState) *neighborEntry { return &neighborEntry{ - nic: nic, - linkRes: linkRes, + cache: cache, nudState: nudState, neigh: NeighborEntry{ Addr: remoteAddr, @@ -121,18 +116,18 @@ func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, li // newStaticNeighborEntry creates a neighbor cache entry starting at the // Static state. The entry can only transition out of Static by directly // calling `setStateLocked`. -func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { +func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { entry := NeighborEntry{ Addr: addr, LinkAddr: linkAddr, State: Static, - UpdatedAtNanos: nic.stack.clock.NowNanoseconds(), + UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(), } - if nic.stack.nudDisp != nil { - nic.stack.nudDisp.OnNeighborAdded(nic.id, entry) + if nudDisp := cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(cache.nic.id, entry) } return &neighborEntry{ - nic: nic, + cache: cache, nudState: state, neigh: entry, } @@ -158,7 +153,7 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // is resolved (which ends up obtaining the entry's lock) while holding the // link resolution queue's lock. Dequeuing packets in a new goroutine avoids // a lock ordering violation. - go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) + go e.cache.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) } } @@ -167,8 +162,8 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchAddEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(e.cache.nic.id, e.neigh) } } @@ -177,8 +172,8 @@ func (e *neighborEntry) dispatchAddEventLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchChangeEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborChanged(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborChanged(e.cache.nic.id, e.neigh) } } @@ -187,8 +182,8 @@ func (e *neighborEntry) dispatchChangeEventLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchRemoveEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborRemoved(e.cache.nic.id, e.neigh) } } @@ -206,7 +201,7 @@ func (e *neighborEntry) cancelJobLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) removeLocked() { - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() e.dispatchRemoveEventLocked() e.cancelJobLocked() e.notifyCompletionLocked(false /* succeeded */) @@ -222,7 +217,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.neigh.State e.neigh.State = next - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() config := e.nudState.Config() switch next { @@ -230,14 +225,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev)) case Reachable: - e.job = e.nic.stack.newJob(&e.mu, func() { + e.job = e.cache.nic.stack.newJob(&e.mu, func() { e.setStateLocked(Stale) e.dispatchChangeEventLocked() }) e.job.Schedule(e.nudState.ReachableTime()) case Delay: - e.job = e.nic.stack.newJob(&e.mu, func() { + e.job = e.cache.nic.stack.newJob(&e.mu, func() { e.setStateLocked(Probe) e.dispatchChangeEventLocked() }) @@ -254,14 +249,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil { + if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return } retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(config.RetransmitTimer) } @@ -269,7 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // for finishing the state transition. This is necessary to avoid // deadlock where sending and processing probes are done synchronously, // such as loopback and integration tests. - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(immediateDuration) case Failed: @@ -292,12 +287,12 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { case Failed: - e.nic.stats.Neighbor.FailedEntryLookups.Increment() + e.cache.nic.stats.Neighbor.FailedEntryLookups.Increment() fallthrough case Unknown: e.neigh.State = Incomplete - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() e.dispatchAddEventLocked() @@ -340,7 +335,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // address SHOULD be placed in the IP Source Address of the outgoing // solicitation. // - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil { + if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil { // There is no need to log the error here; the NUD implementation may // assume a working link. A valid link should be the responsibility of // the NIC/stack.LinkEndpoint. @@ -350,7 +345,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { } retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) e.job.Schedule(config.RetransmitTimer) } @@ -358,7 +353,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // for finishing the state transition. This is necessary to avoid // deadlock where sending and processing probes are done synchronously, // such as loopback and integration tests. - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) e.job.Schedule(immediateDuration) case Stale: @@ -504,7 +499,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // // TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6 // here. - ep, ok := e.nic.networkEndpoints[header.IPv6ProtocolNumber] + ep, ok := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber] if !ok { panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint")) } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 140b8ca00..57cfbdb8b 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -193,7 +193,7 @@ func (p entryTestProbeInfo) String() string { // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts // to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { +func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { p := entryTestProbeInfo{ RemoteAddress: targetAddr, RemoteLinkAddress: linkAddr, @@ -230,22 +230,30 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e }, stats: makeNICStats(), } + netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ - header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil), + header.IPv6ProtocolNumber: netEP, } rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) - linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes) - + var linkRes entryTestLinkResolver // Stub out the neighbor cache to verify deletion from the cache. - nic.neigh = &neighborCache{ - nic: &nic, - state: nudState, - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + neigh := &neighborCache{ + nic: &nic, + state: nudState, + linkRes: &linkRes, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + l := linkResolver{ + resolver: &linkRes, + neighborTable: neigh, + } + entry := newNeighborEntry(neigh, entryTestAddr1 /* remoteAddr */, nudState) + neigh.cache[entryTestAddr1] = entry + nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]linkResolver{ + header.IPv6ProtocolNumber: l, } - nic.neigh.cache[entryTestAddr1] = entry return entry, &disp, &linkRes, clock } @@ -835,7 +843,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) - ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) + ipv6EP := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index e56a624fe..41a489047 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "math/rand" "reflect" "sync/atomic" @@ -25,8 +24,37 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) +type neighborTable interface { + neighbors() ([]NeighborEntry, tcpip.Error) + addStaticEntry(tcpip.Address, tcpip.LinkAddress) + get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) + remove(tcpip.Address) tcpip.Error + removeAll() tcpip.Error + + handleProbe(tcpip.Address, tcpip.LinkAddress) + handleConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) + handleUpperLevelConfirmation(tcpip.Address) + + nudConfig() (NUDConfigurations, tcpip.Error) + setNUDConfig(NUDConfigurations) tcpip.Error +} + var _ NetworkInterface = (*NIC)(nil) +type linkResolver struct { + resolver LinkAddressResolver + + neighborTable neighborTable +} + +func (l *linkResolver) getNeighborLinkAddress(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + return l.neighborTable.get(addr, localAddr, onResolve) +} + +func (l *linkResolver) confirmReachable(addr tcpip.Address) { + l.neighborTable.handleUpperLevelConfirmation(addr) +} + // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { @@ -38,11 +66,11 @@ type NIC struct { context NICContext stats NICStats - neigh *neighborCache // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. - networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + linkAddrResolvers map[tcpip.NetworkProtocolNumber]linkResolver // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. // @@ -53,8 +81,6 @@ type NIC struct { // complete. linkResQueue packetsPendingLinkResolution - linkAddrCache *linkAddrCache - mu struct { sync.RWMutex spoofing bool @@ -133,35 +159,18 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic := &NIC{ LinkEndpoint: ep, - stack: stack, - id: id, - name: name, - context: ctx, - stats: makeNICStats(), - networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), + stack: stack, + id: id, + name: name, + context: ctx, + stats: makeNICStats(), + networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]linkResolver), } nic.linkResQueue.init(nic) - nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) - // Check for Neighbor Unreachability Detection support. - var nud NUDHandler - if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 && stack.useNeighborCache { - rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds())) - nic.neigh = &neighborCache{ - nic: nic, - state: NewNUDState(stack.nudConfigs, rng), - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), - } - - // An interface value that holds a nil pointer but non-nil type is not the - // same as the nil interface. Because of this, nud must only be assignd if - // nic.neigh is non-nil since a nil reference to a neighborCache is not - // valid. - // - // See https://golang.org/doc/faq#nil_error for more information. - nud = nic.neigh - } + resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0 // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { @@ -170,7 +179,32 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC for _, netProto := range stack.networkProtocols { netNum := netProto.Number() nic.mu.packetEPs[netNum] = new(packetEndpointList) - nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic) + + netEP := netProto.NewEndpoint(nic, nic) + nic.networkEndpoints[netNum] = netEP + + if resolutionRequired { + if r, ok := netEP.(LinkAddressResolver); ok { + l := linkResolver{ + resolver: r, + } + + if stack.useNeighborCache { + l.neighborTable = &neighborCache{ + nic: nic, + state: NewNUDState(stack.nudConfigs, stack.randomGenerator), + linkRes: r, + + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + } else { + cache := new(linkAddrCache) + cache.init(nic, ageLimit, resolutionTimeout, resolutionAttempts, r) + l.neighborTable = cache + } + nic.linkAddrResolvers[r.LinkAddressProtocol()] = l + } + } } nic.LinkEndpoint.Attach(nic) @@ -223,18 +257,19 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() - } - // Clear the neighbour table (including static entries) as we cannot guarantee - // that the current neighbour table will be valid when the NIC is enabled - // again. - // - // This matches linux's behaviour at the time of writing: - // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 - switch err := n.clearNeighbors(); err.(type) { - case nil, *tcpip.ErrNotSupported: - default: - panic(fmt.Sprintf("n.clearNeighbors(): %s", err)) + // Clear the neighbour table (including static entries) as we cannot + // guarantee that the current neighbour table will be valid when the NIC is + // enabled again. + // + // This matches linux's behaviour at the time of writing: + // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 + netProto := ep.NetworkProtocolNumber() + switch err := n.clearNeighbors(netProto); err.(type) { + case nil, *tcpip.ErrNotSupported: + default: + panic(fmt.Sprintf("n.clearNeighbors(%d): %s", netProto, err)) + } } if !n.setEnabled(false) { @@ -587,56 +622,52 @@ func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error { return &tcpip.ErrBadLocalAddress{} } -func (n *NIC) confirmReachable(addr tcpip.Address) { - if n := n.neigh; n != nil { - n.handleUpperLevelConfirmation(addr) +func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { + linkRes, ok := n.linkAddrResolvers[protocol] + if !ok { + return &tcpip.ErrNotSupported{} } -} -func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - if n.neigh != nil { - entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve) - return entry.LinkAddr, ch, err + if linkAddr, ok := linkRes.resolver.ResolveStaticAddress(addr); ok { + onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) + return nil } - return n.linkAddrCache.get(addr, linkRes, localAddr, n, onResolve) + _, _, err := linkRes.getNeighborLinkAddress(addr, localAddr, onResolve) + return err } -func (n *NIC) neighbors() ([]NeighborEntry, tcpip.Error) { - if n.neigh == nil { - return nil, &tcpip.ErrNotSupported{} +func (n *NIC) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.neighbors() } - return n.neigh.entries(), nil + return nil, &tcpip.ErrNotSupported{} } -func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) tcpip.Error { - if n.neigh == nil { - return &tcpip.ErrNotSupported{} +func (n *NIC) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtocolNumber, linkAddress tcpip.LinkAddress) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + linkRes.neighborTable.addStaticEntry(addr, linkAddress) + return nil } - n.neigh.addStaticEntry(addr, linkAddress) - return nil + return &tcpip.ErrNotSupported{} } -func (n *NIC) removeNeighbor(addr tcpip.Address) tcpip.Error { - if n.neigh == nil { - return &tcpip.ErrNotSupported{} +func (n *NIC) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.remove(addr) } - if !n.neigh.removeEntry(addr) { - return &tcpip.ErrBadAddress{} - } - return nil + return &tcpip.ErrNotSupported{} } -func (n *NIC) clearNeighbors() tcpip.Error { - if n.neigh == nil { - return &tcpip.ErrNotSupported{} +func (n *NIC) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.removeAll() } - n.neigh.clear() - return nil + return &tcpip.ErrNotSupported{} } // joinGroup adds a new endpoint for the given multicast address, if none @@ -880,9 +911,8 @@ func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt } } -// DeliverTransportControlPacket delivers control packets to the appropriate -// transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) { +// DeliverTransportError implements TransportDispatcher. +func (n *NIC) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -904,7 +934,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) { + if n.stack.demux.deliverError(n, net, trans, transErr, pkt, id) { return } } @@ -920,24 +950,25 @@ func (n *NIC) Name() string { } // nudConfigs gets the NUD configurations for n. -func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) { - if n.neigh == nil { - return NUDConfigurations{}, &tcpip.ErrNotSupported{} +func (n *NIC) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.nudConfig() } - return n.neigh.config(), nil + + return NUDConfigurations{}, &tcpip.ErrNotSupported{} } // setNUDConfigs sets the NUD configurations for n. // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (n *NIC) setNUDConfigs(c NUDConfigurations) tcpip.Error { - if n.neigh == nil { - return &tcpip.ErrNotSupported{} +func (n *NIC) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + c.resetInvalidFields() + return linkRes.neighborTable.setNUDConfig(c) } - c.resetInvalidFields() - n.neigh.setConfig(c) - return nil + + return &tcpip.ErrNotSupported{} } func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { @@ -973,3 +1004,23 @@ func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { n.mu.RUnlock() return n.Enabled() && ep.IsAssigned(spoofing) } + +// HandleNeighborProbe implements NetworkInterface. +func (n *NIC) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { + if l, ok := n.linkAddrResolvers[protocol]; ok { + l.neighborTable.handleProbe(addr, linkAddr) + return nil + } + + return &tcpip.ErrNotSupported{} +} + +// HandleNeighborConfirmation implements NetworkInterface. +func (n *NIC) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) tcpip.Error { + if l, ok := n.linkAddrResolvers[protocol]; ok { + l.neighborTable.handleConfirmation(addr, linkAddr, flags) + return nil + } + + return &tcpip.ErrNotSupported{} +} diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 2f719fbe5..9992d6eb4 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -111,8 +111,6 @@ type testIPv6EndpointStats struct{} // IsNetworkEndpointStats implements stack.NetworkEndpointStats. func (*testIPv6EndpointStats) IsNetworkEndpointStats() {} -var _ LinkAddressResolver = (*testIPv6Protocol)(nil) - // We use this instead of ipv6.protocol because the ipv6 package depends on // the stack package which this test lives in, causing a cyclic dependency. type testIPv6Protocol struct{} @@ -139,7 +137,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) } // NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint { +func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ TransportDispatcher) NetworkEndpoint { e := &testIPv6Endpoint{ nic: nic, protocol: p, @@ -169,24 +167,6 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo return 0, false, false } -// LinkAddressProtocol implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return header.IPv6ProtocolNumber -} - -// LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { - return nil -} - -// ResolveStaticAddress implements LinkAddressResolver. -func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if header.IsV6MulticastAddress(addr) { - return header.EthernetAddressFromMulticastIPv6Address(addr), true - } - return "", false -} - func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 77926e289..5a94e9ac6 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -161,21 +161,6 @@ type ReachabilityConfirmationFlags struct { IsRouter bool } -// NUDHandler communicates external events to the Neighbor Unreachability -// Detection state machine, which is implemented per-interface. This is used by -// network endpoints to inform the Neighbor Cache of probes and confirmations. -type NUDHandler interface { - // HandleProbe processes an incoming neighbor probe (e.g. ARP request or - // Neighbor Solicitation for ARP or NDP, respectively). Validation of the - // probe needs to be performed before calling this function since the - // Neighbor Cache doesn't have access to view the NIC's assigned addresses. - HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) - - // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP - // reply or Neighbor Advertisement for ARP or NDP, respectively). - HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) -} - // NUDConfigurations is the NUD configurations for the netstack. This is used // by the neighbor cache to operate the NUD state machine on each device in the // local network. diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index ebfd5eb45..e9acef6a2 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -19,7 +19,9 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -52,83 +54,146 @@ func (f *fakeRand) Float32() float32 { return f.num } -// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if -// we attempt to update NUD configurations using an invalid NICID. -func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The networking - // stack will only allocate neighbor caches if a protocol providing link - // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - UseNeighborCache: true, - }) +func TestNUDFunctions(t *testing.T) { + const nicID = 1 - // No NIC with ID 1 yet. - config := stack.NUDConfigurations{} - err := s.SetNUDConfigurations(1, config) - if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { - t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, &tcpip.ErrUnknownNICID{}) + tests := []struct { + name string + nicID tcpip.NICID + netProtoFactory []stack.NetworkProtocolFactory + extraLinkCapabilities stack.LinkEndpointCapabilities + expectedErr tcpip.Error + }{ + { + name: "Invalid NICID", + nicID: nicID + 1, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + expectedErr: &tcpip.ErrUnknownNICID{}, + }, + { + name: "No network protocol", + nicID: nicID, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With IPv6", + nicID: nicID, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With resolution capability", + nicID: nicID, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With IPv6 and resolution capability", + nicID: nicID, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + }, } -} -// TestNUDConfigurationFailsForNotSupported tests to make sure we get a -// NotSupported error if we attempt to retrieve NUD configurations when the -// stack doesn't support NUD. -// -// The stack will report to not support NUD if a neighbor cache for a given NIC -// is not allocated. The networking stack will only allocate neighbor caches if -// a protocol providing link address resolution is specified (e.g. ARP, IPv6). -func TestNUDConfigurationFailsForNotSupported(t *testing.T) { - const nicID = 1 + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, + NetworkProtocols: test.netProtoFactory, + Clock: clock, + }) - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + e := channel.New(0, 0, linkAddr1) + e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired + e.LinkEPCapabilities |= test.extraLinkCapabilities - s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - _, err := s.NUDConfigurations(nicID) - if _, ok := err.(*tcpip.ErrNotSupported); !ok { - t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{}) - } -} + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } -// TestNUDConfigurationFailsForNotSupported tests to make sure we get a -// NotSupported error if we attempt to set NUD configurations when the stack -// doesn't support NUD. -// -// The stack will report to not support NUD if a neighbor cache for a given NIC -// is not allocated. The networking stack will only allocate neighbor caches if -// a protocol providing link address resolution is specified (e.g. ARP, IPv6). -func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { - const nicID = 1 + configs := stack.DefaultNUDConfigurations() + configs.BaseReachableTime = time.Hour - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + { + err := s.SetNUDConfigurations(test.nicID, ipv6.ProtocolNumber, configs) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.SetNUDConfigurations(%d, %d, _) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } + } - s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + { + gotConfigs, err := s.NUDConfigurations(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.NUDConfigurations(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if diff := cmp.Diff(configs, gotConfigs); diff != "" { + t.Errorf("got configs mismatch (-want +got):\n%s", diff) + } + } + } - config := stack.NUDConfigurations{} - err := s.SetNUDConfigurations(nicID, config) - if _, ok := err.(*tcpip.ErrNotSupported); !ok { - t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{}) + for _, addr := range []tcpip.Address{llAddr1, llAddr2} { + { + err := s.AddStaticNeighbor(test.nicID, ipv6.ProtocolNumber, addr, linkAddr1) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.AddStaticNeighbor(%d, %d, %s, %s) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, addr, linkAddr1, diff) + } + } + } + + { + wantErr := test.expectedErr + for i := 0; i < 2; i++ { + { + err := s.RemoveNeighbor(test.nicID, ipv6.ProtocolNumber, llAddr1) + if diff := cmp.Diff(wantErr, err); diff != "" { + t.Errorf("s.RemoveNeighbor(%d, %d, '') error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } + } + + if test.expectedErr != nil { + break + } + + // Removing a neighbor that does not exist should give us a bad address + // error. + wantErr = &tcpip.ErrBadAddress{} + } + } + + { + neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if diff := cmp.Diff( + []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + neighbors, + ); diff != "" { + t.Errorf("neighbors mismatch (-want +got):\n%s", diff) + } + } + } + + { + err := s.ClearNeighbors(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.ClearNeigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber); err != nil { + t.Errorf("s.Neighbors(%d, %d): %s", test.nicID, ipv6.ProtocolNumber, err) + } else if len(neighbors) != 0 { + t.Errorf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + } + } + } + }) } } -// TestDefaultNUDConfigurationIsValid verifies that calling -// resetInvalidFields() on the result of DefaultNUDConfigurations() does not -// change anything. DefaultNUDConfigurations() should return a valid -// NUDConfigurations. func TestDefaultNUDConfigurations(t *testing.T) { const nicID = 1 @@ -146,12 +211,12 @@ func TestDefaultNUDConfigurations(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - c, err := s.NUDConfigurations(nicID) + c, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got, want := c, stack.DefaultNUDConfigurations(); got != want { - t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want) + t.Errorf("got stack.NUDConfigurations(%d, %d) = %+v, want = %+v", nicID, ipv6.ProtocolNumber, got, want) } } @@ -201,9 +266,9 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.BaseReachableTime; got != test.want { t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want) @@ -258,9 +323,9 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MinRandomFactor; got != test.want { t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want) @@ -338,9 +403,9 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxRandomFactor; got != test.want { t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want) @@ -400,9 +465,9 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.RetransmitTimer; got != test.want { t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want) @@ -452,9 +517,9 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.DelayFirstProbeTime; got != test.want { t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want) @@ -504,9 +569,9 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxMulticastProbes; got != test.want { t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want) @@ -556,9 +621,9 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxUnicastProbes; got != test.want { t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9d4fc3e48..4f013b212 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -187,6 +187,12 @@ func (pk *PacketBuffer) Size() int { return pk.HeaderSize() + pk.Data.Size() } +// MemSize returns the estimation size of the pk in memory, including backing +// buffer data. +func (pk *PacketBuffer) MemSize() int { + return pk.HeaderSize() + pk.Data.MemSize() + packetBufferStructSize +} + // Views returns the underlying storage of the whole packet. func (pk *PacketBuffer) Views() []buffer.View { // Optimization for outbound packets that headers are in pk.header. diff --git a/pkg/tcpip/stack/packet_buffer_unsafe.go b/pkg/tcpip/stack/packet_buffer_unsafe.go new file mode 100644 index 000000000..ee3d47270 --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_unsafe.go @@ -0,0 +1,19 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import "unsafe" + +const packetBufferStructSize = int(unsafe.Sizeof(PacketBuffer{})) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 510da8689..d589f798d 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -49,31 +49,6 @@ type TransportEndpointID struct { RemoteAddress tcpip.Address } -// ControlType is the type of network control message. -type ControlType int - -// The following are the allowed values for ControlType values. -// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. -const ( - // ControlAddressUnreachable indicates that an IPv6 packet did not reach its - // destination as the destination address was unreachable. - // - // This maps to the ICMPv6 Destination Ureachable Code 3 error; see - // RFC 4443 section 3.1 for more details. - ControlAddressUnreachable ControlType = iota - ControlNetworkUnreachable - // ControlNoRoute indicates that an IPv4 packet did not reach its destination - // because the destination host was unreachable. - // - // This maps to the ICMPv4 Destination Ureachable Code 1 error; see - // RFC 791's Destination Unreachable Message section (page 4) for more - // details. - ControlNoRoute - ControlPacketTooBig - ControlPortUnreachable - ControlUnknown -) - // NetworkPacketInfo holds information about a network layer packet. type NetworkPacketInfo struct { // LocalAddressBroadcast is true if the packet's local address is a broadcast @@ -81,6 +56,39 @@ type NetworkPacketInfo struct { LocalAddressBroadcast bool } +// TransportErrorKind enumerates error types that are handled by the transport +// layer. +type TransportErrorKind int + +const ( + // PacketTooBigTransportError indicates that a packet did not reach its + // destination because a link on the path to the destination had an MTU that + // was too small to carry the packet. + PacketTooBigTransportError TransportErrorKind = iota + + // DestinationHostUnreachableTransportError indicates that the destination + // host was unreachable. + DestinationHostUnreachableTransportError + + // DestinationPortUnreachableTransportError indicates that a packet reached + // the destination host, but the transport protocol was not active on the + // destination port. + DestinationPortUnreachableTransportError + + // DestinationNetworkUnreachableTransportError indicates that the destination + // network was unreachable. + DestinationNetworkUnreachableTransportError +) + +// TransportError is a marker interface for errors that may be handled by the +// transport layer. +type TransportError interface { + tcpip.SockErrorCause + + // Kind returns the type of the transport error. + Kind() TransportErrorKind +} + // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { @@ -93,10 +101,10 @@ type TransportEndpoint interface { // HandlePacket takes ownership of the packet. HandlePacket(TransportEndpointID, *PacketBuffer) - // HandleControlPacket is called by the stack when new control (e.g. - // ICMP) packets arrive to this transport endpoint. - // HandleControlPacket takes ownership of pkt. - HandleControlPacket(typ ControlType, extra uint32, pkt *PacketBuffer) + // HandleError is called when the transport endpoint receives an error. + // + // HandleError takes ownership of the packet buffer. + HandleError(TransportError, *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -248,14 +256,11 @@ type TransportDispatcher interface { // DeliverTransportPacket takes ownership of the packet. DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition - // DeliverTransportControlPacket delivers control packets to the - // appropriate transport protocol endpoint. - // - // pkt.NetworkHeader must be set before calling - // DeliverTransportControlPacket. + // DeliverTransportError delivers an error to the appropriate transport + // endpoint. // - // DeliverTransportControlPacket takes ownership of pkt. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) + // DeliverTransportError takes ownership of the packet buffer. + DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -530,6 +535,17 @@ type NetworkInterface interface { // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) + + // HandleNeighborProbe processes an incoming neighbor probe (e.g. ARP + // request or NDP Neighbor Solicitation). + // + // HandleNeighborProbe assumes that the probe is valid for the network + // interface the probe was received on. + HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error + + // HandleNeighborConfirmation processes an incoming neighbor confirmation + // (e.g. ARP reply or NDP Neighbor Advertisement). + HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) tcpip.Error } // LinkResolvableNetworkEndpoint handles link resolution events. @@ -649,7 +665,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint + NewEndpoint(nic NetworkInterface, dispatcher TransportDispatcher) NetworkEndpoint // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -824,16 +840,12 @@ type InjectableLinkEndpoint interface { InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error } -// A LinkAddressResolver is an extension to a NetworkProtocol that -// can resolve link addresses. +// A LinkAddressResolver handles link address resolution for a network protocol. type LinkAddressResolver interface { // LinkAddressRequest sends a request for the link address of the target // address. The request is broadcasted on the local network if a remote link // address is not provided. - // - // The request is sent from the passed network interface. If the interface - // local address is unspecified, any interface local address may be used. - LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) tcpip.Error + LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the @@ -847,12 +859,6 @@ type LinkAddressResolver interface { LinkAddressProtocol() tcpip.NetworkProtocolNumber } -// A LinkAddressCache caches link addresses. -type LinkAddressCache interface { - // AddLinkAddress adds a link address to the cache. - AddLinkAddress(addr tcpip.Address, linkAddr tcpip.LinkAddress) -} - // RawFactory produces endpoints for writing various types of raw packets. type RawFactory interface { // NewUnassociatedEndpoint produces endpoints for writing packets not diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 4ae0f2a1a..bab55ce49 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -53,7 +53,7 @@ type Route struct { // linkRes is set if link address resolution is enabled for this protocol on // the route's NIC. - linkRes LinkAddressResolver + linkRes linkResolver } type routeInfo struct { @@ -174,7 +174,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA } if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { + if linkRes, ok := r.outgoingNIC.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes } } @@ -184,11 +184,11 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA return r } - if r.linkRes == nil { + if r.linkRes.resolver == nil { return r } - if linkAddr, ok := r.linkRes.ResolveStaticAddress(r.RemoteAddress); ok { + if linkAddr, ok := r.linkRes.resolver.ResolveStaticAddress(r.RemoteAddress); ok { r.ResolveWith(linkAddr) return r } @@ -362,7 +362,7 @@ func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteIn } afterResolveFields := fields - linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) { + linkAddr, ch, err := r.linkRes.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, func(r LinkResolutionResult) { if afterResolve != nil { if r.Success { afterResolveFields.RemoteLinkAddress = r.LinkAddress @@ -400,7 +400,7 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isResolutionRequiredRLocked() bool { - return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local() + return len(r.mu.remoteLinkAddress) == 0 && r.linkRes.resolver != nil && r.isValidForOutgoingRLocked() && !r.local() } func (r *Route) isValidForOutgoing() bool { @@ -528,5 +528,7 @@ func (r *Route) IsOutboundBroadcast() bool { // "Reachable" is defined as having full-duplex communication between the // local and remote ends of the route. func (r *Route) ConfirmReachable() { - r.outgoingNIC.confirmReachable(r.nextHop()) + if r.linkRes.resolver != nil { + r.linkRes.confirmReachable(r.nextHop()) + } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 119c4c505..57ad412a1 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -376,7 +376,6 @@ func (u *uniqueIDGenerator) UniqueID() uint64 { type Stack struct { transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol - linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver // rawFactory creates raw endpoints. If nil, raw endpoints are // disabled. It is set during Stack creation and is immutable. @@ -386,6 +385,15 @@ type Stack struct { stats tcpip.Stats + // LOCK ORDERING: mu > route.mu. + route struct { + mu struct { + sync.RWMutex + + table []tcpip.Route + } + } + mu sync.RWMutex nics map[tcpip.NICID]*NIC @@ -393,11 +401,6 @@ type Stack struct { cleanupEndpointsMu sync.Mutex cleanupEndpoints map[TransportEndpoint]struct{} - // route is the route table passed in by the user via SetRouteTable(), - // it is used by FindRoute() to build a route for a specific - // destination. - routeTable []tcpip.Route - *ports.PortManager // If not nil, then any new endpoints will have this probe function @@ -433,6 +436,8 @@ type Stack struct { // useNeighborCache indicates whether ARP and NDP packets should be handled // by the NIC's neighborCache instead of linkAddrCache. + // + // TODO(gvisor.dev/issue/4658): Remove this field. useNeighborCache bool // nudDisp is the NUD event dispatcher that is used to send the netstack @@ -499,13 +504,17 @@ type Options struct { // NUDConfigs is the default NUD configurations used by interfaces. NUDConfigs NUDConfigurations - // UseNeighborCache indicates whether ARP and NDP packets should be handled - // by the Neighbor Unreachability Detection (NUD) state machine. This flag - // also enables the APIs for inspecting and modifying the neighbor table via - // NUDDispatcher and the following Stack methods: Neighbors, RemoveNeighbor, - // and ClearNeighbors. + // UseNeighborCache is unused. + // + // TODO(gvisor.dev/issue/4658): Remove this field. UseNeighborCache bool + // UseLinkAddrCache indicates that the legacy link address cache should be + // used for link resolution. + // + // TODO(gvisor.dev/issue/4658): Remove this field. + UseLinkAddrCache bool + // NUDDisp is the NUD event dispatcher that an integrator can provide to // receive NUD related events. NUDDisp NUDDispatcher @@ -635,7 +644,6 @@ func New(opts Options) *Stack { s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), cleanupEndpoints: make(map[TransportEndpoint]struct{}), PortManager: ports.NewPortManager(), @@ -646,7 +654,7 @@ func New(opts Options) *Stack { icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), nudConfigs: opts.NUDConfigs, - useNeighborCache: opts.UseNeighborCache, + useNeighborCache: !opts.UseLinkAddrCache, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), @@ -666,9 +674,6 @@ func New(opts Options) *Stack { for _, netProtoFactory := range opts.NetworkProtocols { netProto := netProtoFactory(s) s.networkProtocols[netProto.Number()] = netProto - if r, ok := netProto.(LinkAddressResolver); ok { - s.linkAddrResolvers[r.LinkAddressProtocol()] = r - } } // Add specified transport protocols. @@ -818,38 +823,37 @@ func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { // // This method takes ownership of the table. func (s *Stack) SetRouteTable(table []tcpip.Route) { - s.mu.Lock() - defer s.mu.Unlock() - - s.routeTable = table + s.route.mu.Lock() + defer s.route.mu.Unlock() + s.route.mu.table = table } // GetRouteTable returns the route table which is currently in use. func (s *Stack) GetRouteTable() []tcpip.Route { - s.mu.Lock() - defer s.mu.Unlock() - return append([]tcpip.Route(nil), s.routeTable...) + s.route.mu.RLock() + defer s.route.mu.RUnlock() + return append([]tcpip.Route(nil), s.route.mu.table...) } // AddRoute appends a route to the route table. func (s *Stack) AddRoute(route tcpip.Route) { - s.mu.Lock() - defer s.mu.Unlock() - s.routeTable = append(s.routeTable, route) + s.route.mu.Lock() + defer s.route.mu.Unlock() + s.route.mu.table = append(s.route.mu.table, route) } // RemoveRoutes removes matching routes from the route table. func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { - s.mu.Lock() - defer s.mu.Unlock() + s.route.mu.Lock() + defer s.route.mu.Unlock() var filteredRoutes []tcpip.Route - for _, route := range s.routeTable { + for _, route := range s.route.mu.table { if !match(route) { filteredRoutes = append(filteredRoutes, route) } } - s.routeTable = filteredRoutes + s.route.mu.table = filteredRoutes } // NewEndpoint creates a new transport layer endpoint of the given protocol. @@ -1022,17 +1026,18 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error { delete(s.nics, id) // Remove routes in-place. n tracks the number of routes written. + s.route.mu.Lock() n := 0 - for i, r := range s.routeTable { - s.routeTable[i] = tcpip.Route{} + for i, r := range s.route.mu.table { + s.route.mu.table[i] = tcpip.Route{} if r.NIC != id { // Keep this route. - s.routeTable[n] = r + s.route.mu.table[n] = r n++ } } - - s.routeTable = s.routeTable[:n] + s.route.mu.table = s.route.mu.table[:n] + s.route.mu.Unlock() return nic.remove() } @@ -1357,39 +1362,49 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n // Find a route to the remote with the route table. var chosenRoute tcpip.Route - for _, route := range s.routeTable { - if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { - continue - } + if r := func() *Route { + s.route.mu.RLock() + defer s.route.mu.RUnlock() - nic, ok := s.nics[route.NIC] - if !ok || !nic.Enabled() { - continue - } + for _, route := range s.route.mu.table { + if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { + continue + } - if id == 0 || id == route.NIC { - if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - var gateway tcpip.Address - if needRoute { - gateway = route.Gateway - } - r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) - if r == nil { - panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + nic, ok := s.nics[route.NIC] + if !ok || !nic.Enabled() { + continue + } + + if id == 0 || id == route.NIC { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + var gateway tcpip.Address + if needRoute { + gateway = route.Gateway + } + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) + if r == nil { + panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + } + return r } - return r, nil } - } - // If the stack has forwarding enabled and we haven't found a valid route to - // the remote address yet, keep track of the first valid route. We keep - // iterating because we prefer routes that let us use a local address that - // is assigned to the outgoing interface. There is no requirement to do this - // from any RFC but simply a choice made to better follow a strong host - // model which the netstack follows at the time of writing. - if canForward && chosenRoute == (tcpip.Route{}) { - chosenRoute = route + // If the stack has forwarding enabled and we haven't found a valid route + // to the remote address yet, keep track of the first valid route. We + // keep iterating because we prefer routes that let us use a local + // address that is assigned to the outgoing interface. There is no + // requirement to do this from any RFC but simply a choice made to better + // follow a strong host model which the netstack follows at the time of + // writing. + if canForward && chosenRoute == (tcpip.Route{}) { + chosenRoute = route + } } + + return nil + }(); r != nil { + return r, nil } if chosenRoute != (tcpip.Route{}) { @@ -1517,20 +1532,6 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) tcpip.Error { return nil } -// AddLinkAddress adds a link address for the neighbor on the specified NIC. -func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { - s.mu.RLock() - defer s.mu.RUnlock() - - nic, ok := s.nics[nicID] - if !ok { - return &tcpip.ErrUnknownNICID{} - } - - nic.linkAddrCache.AddLinkAddress(neighbor, linkAddr) - return nil -} - // LinkResolutionResult is the result of a link address resolution attempt. type LinkResolutionResult struct { LinkAddress tcpip.LinkAddress @@ -1561,22 +1562,11 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, return &tcpip.ErrUnknownNICID{} } - linkRes, ok := s.linkAddrResolvers[protocol] - if !ok { - return &tcpip.ErrNotSupported{} - } - - if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok { - onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) - return nil - } - - _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) - return err + return nic.getLinkAddress(addr, localAddr, protocol, onResolve) } // Neighbors returns all IP to MAC address associations. -func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) { +func (s *Stack) Neighbors(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1585,11 +1575,11 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) { return nil, &tcpip.ErrUnknownNICID{} } - return nic.neighbors() + return nic.neighbors(protocol) } // AddStaticNeighbor statically associates an IP address to a MAC address. -func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { +func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1598,13 +1588,13 @@ func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAdd return &tcpip.ErrUnknownNICID{} } - return nic.addStaticNeighbor(addr, linkAddr) + return nic.addStaticNeighbor(addr, protocol, linkAddr) } // RemoveNeighbor removes an IP to MAC address association previously created // either automically or by AddStaticNeighbor. Returns ErrBadAddress if there // is no association with the provided address. -func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Error { +func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1613,11 +1603,11 @@ func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Erro return &tcpip.ErrUnknownNICID{} } - return nic.removeNeighbor(addr) + return nic.removeNeighbor(protocol, addr) } // ClearNeighbors removes all IP to MAC address associations. -func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error { +func (s *Stack) ClearNeighbors(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() @@ -1626,7 +1616,7 @@ func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error { return &tcpip.ErrUnknownNICID{} } - return nic.clearNeighbors() + return nic.clearNeighbors(protocol) } // RegisterTransportEndpoint registers the given endpoint with the stack @@ -1996,7 +1986,7 @@ func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtoco } // NUDConfigurations gets the per-interface NUD configurations. -func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Error) { +func (s *Stack) NUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() @@ -2005,14 +1995,14 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Erro return NUDConfigurations{}, &tcpip.ErrUnknownNICID{} } - return nic.nudConfigs() + return nic.nudConfigs(proto) } // SetNUDConfigurations sets the per-interface NUD configurations. // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip.Error { +func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() @@ -2021,7 +2011,7 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip. return &tcpip.ErrUnknownNICID{} } - return nic.setNUDConfigs(c) + return nic.setNUDConfigs(proto, c) } // Seed returns a 32 bit value that can be used as a seed value for port diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 41f95811f..b641a4aaa 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -137,12 +138,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { return } pkt.Data.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket( + f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), - stack.ControlPortUnreachable, 0, pkt) + // Nothing checks the error. + nil, /* transport error */ + pkt, + ) return } @@ -243,7 +247,7 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &fakeNetworkEndpoint{ nic: nic, proto: f, @@ -4313,9 +4317,11 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") ) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, UseNeighborCache: true, + Clock: clock, }) e := channel.New(0, 0, "") e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -4323,36 +4329,56 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddStaticNeighbor(nicID, ipv4Addr, linkAddr); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv4Addr, linkAddr, err) - } - if err := s.AddStaticNeighbor(nicID, ipv6Addr, linkAddr); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv6Addr, linkAddr, err) + addrs := []struct { + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + }{ + { + proto: ipv4.ProtocolNumber, + addr: ipv4Addr, + }, + { + proto: ipv6.ProtocolNumber, + addr: ipv6Addr, + }, } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 2 { - t.Fatalf("got len(neighbors) = %d, want = 2; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if err := s.AddStaticNeighbor(nicID, addr.proto, addr.addr, linkAddr); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, addr.proto, addr.addr, linkAddr, err) + } + + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if diff := cmp.Diff( + []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + neighbors, + ); diff != "" { + t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff) + } } // Disabling the NIC should clear the neighbor table. if err := s.DisableNIC(nicID); err != nil { t.Fatalf("s.DisableNIC(%d): %s", nicID, err) } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 0 { - t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if len(neighbors) != 0 { + t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) + } } // Enabling the NIC should have an empty neighbor table. if err := s.EnableNIC(nicID); err != nil { t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 0 { - t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if len(neighbors) != 0 { + t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) + } } } @@ -4391,7 +4417,9 @@ func TestStaticGetLinkAddress(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, }) - if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil { + e := channel.New(0, 0, "") + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 26eceb804..7d8d0851e 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -182,9 +182,8 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } -// handleControlPacket delivers a control packet to the transport endpoint -// identified by id. -func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { +// handleError delivers an error to the transport endpoint identified by id. +func (epsByNIC *endpointsByNIC) handleError(n *NIC, id TransportEndpointID, transErr TransportError, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -200,7 +199,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(typ, extra, pkt) + selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns @@ -596,9 +595,11 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb return foundRaw } -// deliverControlPacket attempts to deliver the given control packet. Returns -// true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool { +// deliverError attempts to deliver the given error to the appropriate transport +// endpoint. +// +// Returns true if the error was delivered. +func (d *transportDemuxer) deliverError(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -611,7 +612,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return false } - ep.handleControlPacket(n, id, typ, extra, pkt) + ep.handleError(n, id, transErr, pkt) return true } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index cf5de747b..bebf4e6b5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -237,7 +237,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * f.acceptQueue = append(f.acceptQueue, ep) } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.ControlType, uint32, *stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleError(stack.TransportError, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } |