diff options
35 files changed, 830 insertions, 1165 deletions
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 0efbfb22b..d9f8e3b35 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -31,7 +31,7 @@ type PacketInfo struct { Pkt *stack.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO - Route *stack.Route + Route stack.RouteInfo } // Notification is the interface for receiving notification from the packet @@ -230,15 +230,11 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket stores outbound packets into the channel. func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - // Clone r then release its resource so we only get the relevant fields from - // stack.Route without holding a reference to a NIC's endpoint. - route := r.Clone() - route.Release() p := PacketInfo{ Pkt: pkt, Proto: protocol, GSO: gso, - Route: route, + Route: r.GetFields(), } e.q.Write(p) @@ -248,17 +244,13 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets stores outbound packets into the channel. func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // Clone r then release its resource so we only get the relevant fields from - // stack.Route without holding a reference to a NIC's endpoint. - route := r.Clone() - route.Release() n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ Pkt: pkt, Proto: protocol, GSO: gso, - Route: route, + Route: r.GetFields(), } if !e.q.Write(p) { diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index ce4da7230..a87abc6d6 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -323,9 +323,8 @@ func TestPreserveSrcAddress(t *testing.T) { defer c.cleanup() // Set LocalLinkAddress in route to the value of the bridged address. - r := &stack.Route{ - LocalLinkAddress: baddr, - } + var r stack.Route + r.LocalLinkAddress = baddr r.ResolveWith(raddr) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -335,7 +334,7 @@ func TestPreserveSrcAddress(t *testing.T) { ReserveHeaderBytes: header.EthernetMinimumSize, Data: buffer.VectorisedView{}, }) - if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 3e4afcdad..b511d3a31 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -51,7 +51,8 @@ func TestInjectableEndpointDispatch(t *testing.T) { Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), }) pkt.TransportHeader().Push(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} + var packetRoute stack.Route + packetRoute.RemoteAddress = dstIP endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) @@ -73,7 +74,8 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { Data: buffer.NewView(0).ToVectorisedView(), }) pkt.TransportHeader().Push(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} + var packetRoute stack.Route + packetRoute.RemoteAddress = dstIP endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 27667f5f0..b7458b620 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -154,8 +154,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. - newRoute := r.Clone() - pkt.EgressRoute = newRoute + pkt.EgressRoute = r pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] @@ -178,11 +177,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB for pkt := pkts.Front(); pkt != nil; { d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] nxt := pkt.Next() - // Since qdisc can hold onto a packet for long we should Clone - // the route here to ensure it doesn't get released while the - // packet is still in our queue. - newRoute := pkt.EgressRoute.Clone() - pkt.EgressRoute = newRoute if !d.q.enqueue(pkt) { if enqueued > 0 { d.newPacketWaker.Assert() diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go index eb5abb906..45adcbccb 100644 --- a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go +++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go @@ -61,6 +61,7 @@ func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool { q.mu.Lock() r := q.used < q.limit if r { + s.EgressRoute.Acquire() q.list.PushBack(s) q.used++ } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 7131392cc..dd2e1a125 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -340,9 +340,8 @@ func TestPreserveSrcAddressInSend(t *testing.T) { newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6)) // Set both remote and local link address in route. - r := stack.Route{ - LocalLinkAddress: newLocalLinkAddress, - } + var r stack.Route + r.LocalLinkAddress = newLocalLinkAddress r.ResolveWith(remoteLinkAddr) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index a364c5801..bfac358f4 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -264,7 +264,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { // If the packet does not already have link layer header, and the route // does not exist, we can't compute it. This is possibly a raw packet, tun // device doesn't support this at the moment. - if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress() == "" { + if info.Pkt.LinkHeader().View().IsEmpty() && len(info.Route.RemoteLinkAddress) == 0 { return nil, false } @@ -272,7 +272,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. if info.Pkt.LinkHeader().View().IsEmpty() { - d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress(), info.Proto, info.Pkt) + d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) } vv.AppendView(info.Pkt.LinkHeader().View()) } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 0fb373612..a25cba513 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -441,9 +441,8 @@ func (*testInterface) Promiscuous() bool { } func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - r := stack.Route{ - NetProto: protocol, - } + var r stack.Route + r.NetProto = protocol r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } @@ -557,8 +556,8 @@ func TestLinkAddressRequest(t *testing.T) { t.Fatal("expected to send a link address request") } - if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) } rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 9e2d2cfd6..ef62fe6fc 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -2669,8 +2669,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -2712,8 +2712,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -2761,8 +2761,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != arp.ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != header.EthernetBroadcastAddress { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, header.EthernetBroadcastAddress) + if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) } rep := header.ARP(p.Pkt.NetworkHeader().View()) if got := rep.Op(); got != header.ARPRequest { diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 02b18e9a5..34a6a8446 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -149,9 +149,8 @@ func (*testInterface) Promiscuous() bool { } func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - r := stack.Route{ - NetProto: protocol, - } + var r stack.Route + r.NetProto = protocol r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } @@ -600,8 +599,8 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. return } - if got := pi.Route.RemoteLinkAddress(); len(args.remoteLinkAddr) != 0 && got != args.remoteLinkAddr { - t.Errorf("got remote link address = %s, want = %s", got, args.remoteLinkAddr) + if len(args.remoteLinkAddr) != 0 && pi.Route.RemoteLinkAddress != args.remoteLinkAddr { + t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) } // Pull the full payload since network header. Needed for header.IPv6 to @@ -1381,8 +1380,8 @@ func TestLinkAddressRequest(t *testing.T) { if !ok { t.Fatal("expected to send a link address request") } - if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) } if pkt.Route.RemoteAddress != test.expectedRemoteAddr { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) @@ -1463,8 +1462,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1505,8 +1504,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1556,8 +1555,8 @@ func TestPacketQueing(t *testing.T) { t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) } snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) - if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 05a0d95b2..7ddb19c00 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -650,8 +650,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != respNSDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) } - if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(respNSDst); got != want { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) + if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -706,8 +706,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != test.naDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) } - if got := p.Route.RemoteLinkAddress(); got != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, test.naDstLinkAddr) + if p.Route.RemoteLinkAddress != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 9cc6074da..bb30556cf 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -148,7 +148,6 @@ go_test( ], library = ":stack", deps = [ - "//pkg/sleep", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 5ec9b3411..93e8e1c51 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -560,6 +560,38 @@ func TestForwardingWithNoResolver(t *testing.T) { } } +func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 50 * time.Millisecond, + onLinkAddressResolved: func(*linkAddrCache, *neighborCache, tcpip.Address, tcpip.LinkAddress) { + // Don't resolve the link address. + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto, true /* useNeighborCache */) + + const numPackets int = 5 + // These packets will all be enqueued in the packet queue to wait for link + // address resolution. + for i := 0; i < numPackets; i++ { + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } + + // All packets should fail resolution. + // TODO(gvisor.dev/issue/5141): Use a fake clock. + for i := 0; i < numPackets; i++ { + select { + case got := <-ep2.C: + t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got) + case <-time.After(100 * time.Millisecond): + } + } +} + func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { tests := []struct { name string diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index c9b13cd0e..792f4f170 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -18,7 +18,6 @@ import ( "fmt" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -58,9 +57,6 @@ const ( incomplete entryState = iota // ready means that the address has been resolved and can be used. ready - // failed means that address resolution timed out and the address - // could not be resolved. - failed ) // String implements Stringer. @@ -70,8 +66,6 @@ func (s entryState) String() string { return "incomplete" case ready: return "ready" - case failed: - return "failed" default: return fmt.Sprintf("unknown(%d)", s) } @@ -80,40 +74,48 @@ func (s entryState) String() string { // A linkAddrEntry is an entry in the linkAddrCache. // This struct is thread-compatible. type linkAddrEntry struct { + // linkAddrEntryEntry access is synchronized by the linkAddrCache lock. linkAddrEntryEntry + // TODO(gvisor.dev/issue/5150): move these fields under mu. + // mu protects the fields below. + mu sync.RWMutex + addr tcpip.FullAddress linkAddr tcpip.LinkAddress expiration time.Time s entryState - // wakers is a set of waiters for address resolution result. Anytime - // state transitions out of incomplete these waiters are notified. - wakers map[*sleep.Waker]struct{} - - // done is used to allow callers to wait on address resolution. It is nil iff - // s is incomplete and resolution is not yet in progress. + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. done chan struct{} + + // onResolve is called with the result of address resolution. + onResolve []func(tcpip.LinkAddress, bool) } -// changeState sets the entry's state to ns, notifying any waiters. +func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { + for _, callback := range e.onResolve { + callback(linkAddr, len(linkAddr) != 0) + } + e.onResolve = nil + if ch := e.done; ch != nil { + close(ch) + e.done = nil + } +} + +// changeStateLocked sets the entry's state to ns. // // The entry's expiration is bumped up to the greater of itself and the passed // expiration; the zero value indicates immediate expiration, and is set // unconditionally - this is an implementation detail that allows for entries // to be reused. -func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { - // Notify whoever is waiting on address resolution when transitioning - // out of incomplete. - if e.s == incomplete && ns != incomplete { - for w := range e.wakers { - w.Assert() - } - e.wakers = nil - if ch := e.done; ch != nil { - close(ch) - } - e.done = nil +// +// Precondition: e.mu must be locked +func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { + if e.s == incomplete && ns == ready { + e.notifyCompletionLocked(e.linkAddr) } if expiration.IsZero() || expiration.After(e.expiration) { @@ -122,10 +124,6 @@ func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { e.s = ns } -func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { - delete(e.wakers, w) -} - // add adds a k -> v mapping to the cache. func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { // Calculate expiration time before acquiring the lock, since expiration is @@ -135,10 +133,12 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { c.cache.Lock() entry := c.getOrCreateEntryLocked(k) - entry.linkAddr = v - - entry.changeState(ready, expiration) c.cache.Unlock() + + entry.mu.Lock() + defer entry.mu.Unlock() + entry.linkAddr = v + entry.changeStateLocked(ready, expiration) } // getOrCreateEntryLocked retrieves a cache entry associated with k. The @@ -159,13 +159,14 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt var entry *linkAddrEntry if len(c.cache.table) == linkAddrCacheSize { entry = c.cache.lru.Back() + entry.mu.Lock() delete(c.cache.table, entry.addr) c.cache.lru.Remove(entry) - // Wake waiters and mark the soon-to-be-reused entry as expired. Note - // that the state passed doesn't matter when the zero time is passed. - entry.changeState(failed, time.Time{}) + // Wake waiters and mark the soon-to-be-reused entry as expired. + entry.notifyCompletionLocked("" /* linkAddr */) + entry.mu.Unlock() } else { entry = new(linkAddrEntry) } @@ -180,9 +181,12 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { + if onResolve != nil { + onResolve(addr, true) + } return addr, nil, nil } } @@ -190,56 +194,35 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo c.cache.Lock() defer c.cache.Unlock() entry := c.getOrCreateEntryLocked(k) + entry.mu.Lock() + defer entry.mu.Unlock() + switch s := entry.s; s { - case ready, failed: + case ready: if !time.Now().After(entry.expiration) { // Not expired. - switch s { - case ready: - return entry.linkAddr, nil, nil - case failed: - return entry.linkAddr, nil, tcpip.ErrNoLinkAddress - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) + if onResolve != nil { + onResolve(entry.linkAddr, true) } + return entry.linkAddr, nil, nil } - entry.changeState(incomplete, time.Time{}) + entry.changeStateLocked(incomplete, time.Time{}) fallthrough case incomplete: - if waker != nil { - if entry.wakers == nil { - entry.wakers = make(map[*sleep.Waker]struct{}) - } - entry.wakers[waker] = struct{}{} + if onResolve != nil { + entry.onResolve = append(entry.onResolve, onResolve) } - if entry.done == nil { - // Address resolution needs to be initiated. - if linkRes == nil { - return entry.linkAddr, nil, tcpip.ErrNoLinkAddress - } - entry.done = make(chan struct{}) go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } - return entry.linkAddr, entry.done, tcpip.ErrWouldBlock default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } } -// removeWaker removes a waker previously added through get(). -func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { - c.cache.Lock() - defer c.cache.Unlock() - - if entry, ok := c.cache.table[k]; ok { - entry.removeWaker(waker) - } -} - func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check @@ -257,9 +240,9 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link } } -// checkLinkRequest checks whether previous attempt to resolve address has succeeded -// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request -// can stop, false if another request should be sent. +// checkLinkRequest checks whether previous attempt to resolve address has +// succeeded and mark the entry accordingly. Returns true if request can stop, +// false if another request should be sent. func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { c.cache.Lock() defer c.cache.Unlock() @@ -268,16 +251,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att // Entry was evicted from the cache. return true } + entry.mu.Lock() + defer entry.mu.Unlock() + switch s := entry.s; s { - case ready, failed: - // Entry was made ready by resolver or failed. Either way we're done. + case ready: + // Entry was made ready by resolver. case incomplete: if attempt+1 < c.resolutionAttempts { // No response yet, need to send another ARP request. return false } - // Max number of retries reached, mark entry as failed. - entry.changeState(failed, now.Add(c.ageLimit)) + // Max number of retries reached, delete entry. + entry.notifyCompletionLocked("" /* linkAddr */) + delete(c.cache.table, k) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index d2e37f38d..6883045b5 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -21,7 +21,6 @@ import ( "testing" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -50,6 +49,7 @@ type testLinkAddressResolver struct { } func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { + // TODO(gvisor.dev/issue/5141): Use a fake clock. time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { f() @@ -78,16 +78,18 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe } func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { - w := sleep.Waker{} - s := sleep.Sleeper{} - s.AddWaker(&w, 123) - defer s.Done() - + var attemptedResolution bool for { - if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - return got, err + got, ch, err := c.get(addr, linkRes, "", nil, nil) + if err == tcpip.ErrWouldBlock { + if attemptedResolution { + return got, tcpip.ErrNoLinkAddress + } + attemptedResolution = true + <-ch + continue } - s.Fetch(true) + return got, err } } @@ -116,16 +118,19 @@ func TestCacheOverflow(t *testing.T) { } } // The earliest entries should no longer be in the cache. + c.cache.Lock() + defer c.cache.Unlock() for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { e := testAddrs[i] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) + if entry, ok := c.cache.table[e.addr]; ok { + t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) } } } func TestCacheConcurrent(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + linkRes := &testLinkAddressResolver{cache: c} var wg sync.WaitGroup for r := 0; r < 16; r++ { @@ -133,7 +138,6 @@ func TestCacheConcurrent(t *testing.T) { go func() { for _, e := range testAddrs { c.add(e.addr, e.linkAddr) - c.get(e.addr, nil, "", nil, nil) // make work for gotsan } wg.Done() }() @@ -144,7 +148,7 @@ func TestCacheConcurrent(t *testing.T) { // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) } @@ -153,18 +157,22 @@ func TestCacheConcurrent(t *testing.T) { } e = testAddrs[0] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + c.cache.Lock() + defer c.cache.Unlock() + if entry, ok := c.cache.table[e.addr]; ok { + t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) } } func TestCacheAgeLimit(t *testing.T) { c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) + linkRes := &testLinkAddressResolver{cache: c} + e := testAddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err) } } @@ -282,71 +290,3 @@ func TestStaticResolution(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want)) } } - -// TestCacheWaker verifies that RemoveWaker removes a waker previously added -// through get(). -func TestCacheWaker(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - - // First, sanity check that wakers are working. - { - linkRes := &testLinkAddressResolver{cache: c} - s := sleep.Sleeper{} - defer s.Done() - - const wakerID = 1 - w := sleep.Waker{} - s.AddWaker(&w, wakerID) - - e := testAddrs[0] - - if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) - } - id, ok := s.Fetch(true /* block */) - if !ok { - t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)") - } - if id != wakerID { - t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID) - } - - if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil { - t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) - } else if got != e.linkAddr { - t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) - } - } - - // Check that RemoveWaker works. - { - linkRes := &testLinkAddressResolver{cache: c} - s := sleep.Sleeper{} - defer s.Done() - - const wakerID = 2 // different than the ID used in the sanity check - w := sleep.Waker{} - s.AddWaker(&w, wakerID) - - e := testAddrs[1] - linkRes.onLinkAddressRequest = func() { - // Remove the waker before the linkAddrCache has the opportunity to send - // a notification. - c.removeWaker(e.addr, &w) - } - - if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) - } - - if got, err := getBlocking(c, e.addr, linkRes); err != nil { - t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) - } else if got != e.linkAddr { - t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) - } - - if id, ok := s.Fetch(false /* block */); ok { - t.Fatalf("unexpected notification from waker with id %d", id) - } - } -} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 03d7b4e0d..61636cae5 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -540,8 +540,8 @@ func TestDADResolve(t *testing.T) { // Make sure the right remote link address is used. snmc := header.SolicitedNodeAddr(addr1) - if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want { - t.Errorf("got remote link address = %s, want = %s", got, want) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } // Check NDP NS packet. @@ -5197,8 +5197,8 @@ func TestRouterSolicitation(t *testing.T) { } // Make sure the right remote link address is used. - if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); got != want { - t.Errorf("got remote link address = %s, want = %s", got, want) + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 317f6871d..c15f10e76 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -17,7 +17,6 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -99,9 +98,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA n.dynamic.lru.Remove(e) n.dynamic.count-- - e.dispatchRemoveEventLocked() - e.setStateLocked(Unknown) - e.notifyWakersLocked() + e.removeLocked() e.mu.Unlock() } n.cache[remoteAddr] = entry @@ -110,21 +107,27 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA return entry } -// entry looks up the neighbor cache for translating address to link address -// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there -// is a LinkAddressResolver registered with the network protocol, the cache -// attempts to resolve the address and returns ErrWouldBlock. If a Waker is -// provided, it will be notified when address resolution is complete (success -// or not). +// entry looks up neighbor information matching the remote address, and returns +// it if readily available. +// +// Returns ErrWouldBlock if the link address is not readily available, along +// with a notification channel for the caller to block on. Triggers address +// resolution asynchronously. +// +// If onResolve is provided, it will be called either immediately, if resolution +// is not required, or when address resolution is complete, with the resolved +// link address and whether resolution succeeded. After any callbacks have been +// called, the returned notification channel is closed. +// +// NB: if a callback is provided, it should not call into the neighbor cache. // // If specified, the local address must be an address local to the interface the // neighbor cache belongs to. The local address is the source address of a // packet prompting NUD/link address resolution. // -// If address resolution is required, ErrNoLinkAddress and a notification -// channel is returned for the top level caller to block. Channel is closed -// once address resolution is complete (success or not). -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { +// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) { + // TODO(gvisor.dev/issue/5149): Handle static resolution in route.Resolve. if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { e := NeighborEntry{ Addr: remoteAddr, @@ -132,6 +135,9 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA State: Static, UpdatedAtNanos: 0, } + if onResolve != nil { + onResolve(linkAddr, true) + } return e, nil, nil } @@ -149,37 +155,25 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA // of packets to a neighbor. While reasserting a neighbor's reachability, // a node continues sending packets to that neighbor using the cached // link-layer address." + if onResolve != nil { + onResolve(entry.neigh.LinkAddr, true) + } return entry.neigh, nil, nil - case Unknown, Incomplete: - entry.addWakerLocked(w) - + case Unknown, Incomplete, Failed: + if onResolve != nil { + entry.onResolve = append(entry.onResolve, onResolve) + } if entry.done == nil { // Address resolution needs to be initiated. - if linkRes == nil { - return entry.neigh, nil, tcpip.ErrNoLinkAddress - } entry.done = make(chan struct{}) } - entry.handlePacketQueuedLocked(localAddr) return entry.neigh, entry.done, tcpip.ErrWouldBlock - case Failed: - return entry.neigh, nil, tcpip.ErrNoLinkAddress default: panic(fmt.Sprintf("Invalid cache entry state: %s", s)) } } -// removeWaker removes a waker that has been added when link resolution for -// addr was requested. -func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) { - n.mu.Lock() - if entry, ok := n.cache[addr]; ok { - delete(entry.wakers, waker) - } - n.mu.Unlock() -} - // entries returns all entries in the neighbor cache. func (n *neighborCache) entries() []NeighborEntry { n.mu.RLock() @@ -222,34 +216,13 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd return } - // Notify that resolution has been interrupted, just in case the entry was - // in the Incomplete or Probe state. - entry.dispatchRemoveEventLocked() - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() + entry.removeLocked() entry.mu.Unlock() } n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) } -// removeEntryLocked removes the specified entry from the neighbor cache. -// -// Prerequisite: n.mu and entry.mu MUST be locked. -func (n *neighborCache) removeEntryLocked(entry *neighborEntry) { - if entry.neigh.State != Static { - n.dynamic.lru.Remove(entry) - n.dynamic.count-- - } - if entry.neigh.State != Failed { - entry.dispatchRemoveEventLocked() - } - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() - - delete(n.cache, entry.neigh.Addr) -} - // removeEntry removes a dynamic or static entry by address from the neighbor // cache. Returns true if the entry was found and deleted. func (n *neighborCache) removeEntry(addr tcpip.Address) bool { @@ -264,7 +237,13 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool { entry.mu.Lock() defer entry.mu.Unlock() - n.removeEntryLocked(entry) + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } + + entry.removeLocked() + delete(n.cache, entry.neigh.Addr) return true } @@ -275,9 +254,7 @@ func (n *neighborCache) clear() { for _, entry := range n.cache { entry.mu.Lock() - entry.dispatchRemoveEventLocked() - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() + entry.removeLocked() entry.mu.Unlock() } diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 732a299f7..a2ed6ae2a 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -28,7 +28,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" ) @@ -190,15 +189,18 @@ type testNeighborResolver struct { entries *testEntryStore delay time.Duration onLinkAddressRequest func() + dropReplies bool } var _ LinkAddressResolver = (*testNeighborResolver)(nil) func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { - // Delay handling the request to emulate network latency. - r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(targetAddr) - }) + if !r.dropReplies { + // Delay handling the request to emulate network latency. + r.clock.AfterFunc(r.delay, func() { + r.fakeRequest(targetAddr) + }) + } // Execute post address resolution action, if available. if f := r.onLinkAddressRequest; f != nil { @@ -291,10 +293,10 @@ func TestNeighborCacheEntry(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -327,7 +329,7 @@ func TestNeighborCacheEntry(t *testing.T) { } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -354,11 +356,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -413,7 +415,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } } @@ -461,7 +463,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { return fmt.Errorf("c.store.entry(%d) not found", i) } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -513,7 +515,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { } // Expect to find only the most recent entries. The order of entries reported - // by entries() is undeterministic, so entries have to be sorted before + // by entries() is nondeterministic, so entries have to be sorted before // comparison. wantUnsortedEntries := opts.wantStaticEntries for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { @@ -575,10 +577,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ @@ -650,7 +652,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -694,7 +696,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -756,7 +758,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -826,10 +828,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -907,150 +909,6 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { } } -func TestNeighborCacheNotifiesWaker(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } - - w := sleep.Waker{} - s := sleep.Sleeper{} - const wakerID = 1 - s.AddWaker(&w, wakerID) - - entry, ok := store.entry(0) - if !ok { - t.Fatalf("store.entry(0) not found") - } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) - } - clock.Advance(typicalLatency) - - select { - case <-doneCh: - default: - t.Fatal("expected notification from done channel") - } - - id, ok := s.Fetch(false /* block */) - if !ok { - t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr) - } - if id != wakerID { - t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } -} - -func TestNeighborCacheRemoveWaker(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } - - w := sleep.Waker{} - s := sleep.Sleeper{} - const wakerID = 1 - s.AddWaker(&w, wakerID) - - entry, ok := store.entry(0) - if !ok { - t.Fatalf("store.entry(0) not found") - } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) - } - - // Remove the waker before the neighbor cache has the opportunity to send a - // notification. - neigh.removeWaker(entry.Addr, &w) - clock.Advance(typicalLatency) - - select { - case <-doneCh: - default: - t.Fatal("expected notification from done channel") - } - - if id, ok := s.Fetch(false /* block */); ok { - t.Errorf("unexpected notification from waker with id %d", id) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } -} - func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { config := DefaultNUDConfigurations() // Stay in Reachable so the cache can overflow @@ -1062,12 +920,12 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) if err != nil { - t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1075,7 +933,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -1129,10 +987,10 @@ func TestNeighborCacheClear(t *testing.T) { // Add a dynamic entry. entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -1187,7 +1045,7 @@ func TestNeighborCacheClear(t *testing.T) { } } - // Clear shoud remove both dynamic and static entries. + // Clear should remove both dynamic and static entries. neigh.clear() // Remove events dispatched from clear() have no deterministic order so they @@ -1234,10 +1092,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -1318,7 +1176,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { frequentlyUsedEntry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } // The following logic is very similar to overflowCache, but @@ -1330,15 +1188,22 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { @@ -1373,7 +1238,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err) } } @@ -1381,15 +1246,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy @@ -1435,7 +1308,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Expect to find only the frequently used entry and the most recent entries. - // The order of entries reported by entries() is undeterministic, so entries + // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. wantUnsortedEntries := []NeighborEntry{ { @@ -1494,12 +1367,12 @@ func TestNeighborCacheConcurrent(t *testing.T) { go func(entry NeighborEntry) { defer wg.Done() if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) } }(entry) } - // Wait for all gorountines to send a request + // Wait for all goroutines to send a request wg.Wait() // Process all the requests for a single entry concurrently @@ -1509,7 +1382,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { // All goroutines add in the same order and add more values than can fit in // the cache. Our eviction strategy requires that the last entries are // present, up to the size of the neighbor cache, and the rest are missing. - // The order of entries reported by entries() is undeterministic, so entries + // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. var wantUnsortedEntries []NeighborEntry for i := store.size() - neighborCacheSize; i < store.size(); i++ { @@ -1547,27 +1420,32 @@ func TestNeighborCacheReplace(t *testing.T) { // Add an entry entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } // Verify the entry exists { - e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) - } - if doneCh != nil { - t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh) + t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) } if t.Failed() { t.FailNow() @@ -1578,7 +1456,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } @@ -1587,7 +1465,7 @@ func TestNeighborCacheReplace(t *testing.T) { { entry, ok := store.entry(1) if !ok { - t.Fatalf("store.entry(1) not found") + t.Fatal("store.entry(1) not found") } updatedLinkAddr = entry.LinkAddr } @@ -1604,7 +1482,7 @@ func TestNeighborCacheReplace(t *testing.T) { { e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1612,7 +1490,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Delay, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } @@ -1622,7 +1500,7 @@ func TestNeighborCacheReplace(t *testing.T) { e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1630,7 +1508,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } } @@ -1654,18 +1532,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { }, } - // First, sanity check that resolution is working entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + + // First, sanity check that resolution is working + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + clock.Advance(typicalLatency) + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } - clock.Advance(typicalLatency) + got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1673,20 +1568,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { State: Reachable, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } - // Verify that address resolution for an unknown address returns ErrNoLinkAddress + // Verify address resolution fails for an unknown address. before := atomic.LoadUint32(&requestCount) entry.Addr += "2" - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) + clock.Advance(waitFor) + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } maxAttempts := neigh.config().MaxUnicastProbes @@ -1714,15 +1624,129 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) + + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } +} + +// TestNeighborCacheRetryResolution simulates retrying communication after +// failing to perform address resolution. +func TestNeighborCacheRetryResolution(t *testing.T) { + config := DefaultNUDConfigurations() + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + // Simulate a faulty link. + dropReplies: true, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatal("store.entry(0) not found") + } + + // Perform address resolution with a faulty link, which will fail. + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) + clock.Advance(waitFor) + + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } + } + + // Verify the entry is in Failed state. + wantEntries := []NeighborEntry{ + { + Addr: entry.Addr, + LinkAddr: "", + State: Failed, + }, + } + if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { + t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // Retry address resolution with a working link. + linkRes.dropReplies = false + { + incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + if incompleteEntry.State != Incomplete { + t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) + } + clock.Advance(typicalLatency) + + select { + case <-ch: + if !ok { + t.Fatal("expected successful address resolution") + } + reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if err != nil { + t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) + } + if reachableEntry.Addr != entry.Addr { + t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr) + } + if reachableEntry.LinkAddr != entry.LinkAddr { + t.Fatalf("got entry.LinkAddr = %s, want = %s", reachableEntry.LinkAddr, entry.LinkAddr) + } + if reachableEntry.State != Reachable { + t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String()) + } + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } } @@ -1742,7 +1766,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) { got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", testEntryBroadcastAddr, err) } want := NeighborEntry{ Addr: testEntryBroadcastAddr, @@ -1750,7 +1774,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) { State: Static, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) } } @@ -1775,12 +1799,23 @@ func BenchmarkCacheClear(b *testing.B) { if !ok { b.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + b.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } - if doneCh != nil { - <-doneCh + + select { + case <-ch: + default: + b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } } diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 32399b4f5..75afb3001 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -19,7 +19,6 @@ import ( "sync" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -67,8 +66,7 @@ const ( // Static describes entries that have been explicitly added by the user. They // do not expire and are not deleted until explicitly removed. Static - // Failed means traffic should not be sent to this neighbor since attempts of - // reachability have returned inconclusive. + // Failed means recent attempts of reachability have returned inconclusive. Failed ) @@ -93,16 +91,13 @@ type neighborEntry struct { neigh NeighborEntry - // wakers is a set of waiters for address resolution result. Anytime state - // transitions out of incomplete these waiters are notified. It is nil iff - // address resolution is ongoing and no clients are waiting for the result. - wakers map[*sleep.Waker]struct{} - - // done is used to allow callers to wait on address resolution. It is nil - // iff nudState is not Reachable and address resolution is not yet in - // progress. + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. done chan struct{} + // onResolve is called with the result of address resolution. + onResolve []func(tcpip.LinkAddress, bool) + isRouter bool job *tcpip.Job } @@ -143,25 +138,15 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd } } -// addWaker adds w to the list of wakers waiting for address resolution. -// Assumes the entry has already been appropriately locked. -func (e *neighborEntry) addWakerLocked(w *sleep.Waker) { - if w == nil { - return - } - if e.wakers == nil { - e.wakers = make(map[*sleep.Waker]struct{}) - } - e.wakers[w] = struct{}{} -} - -// notifyWakersLocked notifies those waiting for address resolution, whether it -// succeeded or failed. Assumes the entry has already been appropriately locked. -func (e *neighborEntry) notifyWakersLocked() { - for w := range e.wakers { - w.Assert() +// notifyCompletionLocked notifies those waiting for address resolution, with +// the link address if resolution completed successfully. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { + for _, callback := range e.onResolve { + callback(e.neigh.LinkAddr, succeeded) } - e.wakers = nil + e.onResolve = nil if ch := e.done; ch != nil { close(ch) e.done = nil @@ -170,6 +155,8 @@ func (e *neighborEntry) notifyWakersLocked() { // dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has // been added. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchAddEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborAdded(e.nic.id, e.neigh) @@ -178,6 +165,8 @@ func (e *neighborEntry) dispatchAddEventLocked() { // dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry // has changed state or link-layer address. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchChangeEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborChanged(e.nic.id, e.neigh) @@ -186,23 +175,41 @@ func (e *neighborEntry) dispatchChangeEventLocked() { // dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry // has been removed. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchRemoveEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) } } +// cancelJobLocked cancels the currently scheduled action, if there is one. +// Entries in Unknown, Stale, or Static state do not have a scheduled action. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) cancelJobLocked() { + if job := e.job; job != nil { + job.Cancel() + } +} + +// removeLocked prepares the entry for removal. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) removeLocked() { + e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.dispatchRemoveEventLocked() + e.cancelJobLocked() + e.notifyCompletionLocked(false /* succeeded */) +} + // setStateLocked transitions the entry to the specified state immediately. // // Follows the logic defined in RFC 4861 section 7.3.3. // -// e.mu MUST be locked. +// Precondition: e.mu MUST be locked. func (e *neighborEntry) setStateLocked(next NeighborState) { - // Cancel the previously scheduled action, if there is one. Entries in - // Unknown, Stale, or Static state do not have scheduled actions. - if timer := e.job; timer != nil { - timer.Cancel() - } + e.cancelJobLocked() prev := e.neigh.State e.neigh.State = next @@ -257,11 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { e.job.Schedule(immediateDuration) case Failed: - e.notifyWakersLocked() - e.job = e.nic.stack.newJob(&doubleLock{first: &e.nic.neigh.mu, second: &e.mu}, func() { - e.nic.neigh.removeEntryLocked(e) - }) - e.job.Schedule(config.UnreachableTime) + e.notifyCompletionLocked(false /* succeeded */) case Unknown, Stale, Static: // Do nothing @@ -275,8 +278,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // being queued for outgoing transmission. // // Follows the logic defined in RFC 4861 section 7.3.3. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { + case Failed: + e.nic.stats.Neighbor.FailedEntryLookups.Increment() + + fallthrough case Unknown: e.neigh.State = Incomplete e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() @@ -309,7 +318,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // implementation may find it convenient in some cases to return errors // to the sender by taking the offending packet, generating an ICMP // error message, and then delivering it (locally) through the generic - // error-handling routines.' - RFC 4861 section 2.1 + // error-handling routines." - RFC 4861 section 2.1 e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return @@ -349,8 +358,6 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { case Incomplete, Reachable, Delay, Probe, Static: // Do nothing - case Failed: - e.nic.stats.Neighbor.FailedEntryLookups.Increment() default: panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) } @@ -360,18 +367,30 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // Neighbor Solicitation for ARP or NDP, respectively). // // Follows the logic defined in RFC 4861 section 7.2.3. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // Probes MUST be silently discarded if the target address is tentative, does // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These // checks MUST be done by the NetworkEndpoint. switch e.neigh.State { - case Unknown, Incomplete, Failed: + case Unknown, Failed: e.neigh.LinkAddr = remoteLinkAddr e.setStateLocked(Stale) - e.notifyWakersLocked() e.dispatchAddEventLocked() + case Incomplete: + // "If an entry already exists, and the cached link-layer address + // differs from the one in the received Source Link-Layer option, the + // cached address should be replaced by the received address, and the + // entry's reachability state MUST be set to STALE." + // - RFC 4861 section 7.2.3 + e.neigh.LinkAddr = remoteLinkAddr + e.setStateLocked(Stale) + e.notifyCompletionLocked(true /* succeeded */) + e.dispatchChangeEventLocked() + case Reachable, Delay, Probe: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr @@ -404,6 +423,8 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // not be possible. SEND uses RSA key pairs to produce Cryptographically // Generated Addresses (CGA), as defined in RFC 3972. This ensures that the // claimed source of an NDP message is the owner of the claimed address. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { switch e.neigh.State { case Incomplete: @@ -422,7 +443,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } e.dispatchChangeEventLocked() e.isRouter = flags.IsRouter - e.notifyWakersLocked() + e.notifyCompletionLocked(true /* succeeded */) // "Note that the Override flag is ignored if the entry is in the // INCOMPLETE state." - RFC 4861 section 7.2.5 @@ -457,7 +478,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla wasReachable := e.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) - e.notifyWakersLocked() + e.notifyCompletionLocked(true /* succeeded */) if !wasReachable { e.dispatchChangeEventLocked() } @@ -495,6 +516,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // handleUpperLevelConfirmationLocked processes an incoming upper-level protocol // (e.g. TCP acknowledgements) reachability confirmation. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleUpperLevelConfirmationLocked() { switch e.neigh.State { case Reachable, Stale, Delay, Probe: @@ -512,23 +535,3 @@ func (e *neighborEntry) handleUpperLevelConfirmationLocked() { panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) } } - -// doubleLock combines two locks into one while maintaining lock ordering. -// -// TODO(gvisor.dev/issue/4796): Remove this once subsequent traffic to a Failed -// neighbor is allowed. -type doubleLock struct { - first, second sync.Locker -} - -// Lock locks both locks in order: first then second. -func (l *doubleLock) Lock() { - l.first.Lock() - l.second.Lock() -} - -// Unlock unlocks both locks in reverse order: second then first. -func (l *doubleLock) Unlock() { - l.second.Unlock() - l.first.Unlock() -} diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index c497d3932..ec34ffa5a 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -25,7 +25,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -73,36 +72,36 @@ func eventDiffOptsWithSort() []cmp.Option { // The following unit tests exercise every state transition and verify its // behavior with RFC 4681. // -// | From | To | Cause | Action | Event | -// | ========== | ========== | ========================================== | =============== | ======= | -// | Unknown | Unknown | Confirmation w/ unknown address | | Added | -// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added | -// | Unknown | Stale | Probe w/ unknown address | | Added | -// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed | -// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed | -// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed | -// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed | -// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | | -// | Reachable | Stale | Reachable timer expired | | Changed | -// | Reachable | Stale | Probe or confirmation w/ different address | | Changed | -// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Stale | Stale | Override confirmation | Update LinkAddr | Changed | -// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | -// | Stale | Delay | Packet queued | | Changed | -// | Delay | Reachable | Upper-layer confirmation | | Changed | -// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Delay | Stale | Probe or confirmation w/ different address | | Changed | -// | Delay | Probe | Delay timer expired | Send probe | Changed | -// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed | -// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Probe | Stale | Probe or confirmation w/ different address | | Changed | -// | Probe | Probe | Retransmit timer expired | Send probe | Changed | -// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | -// | Failed | Failed | Packet queued | | | -// | Failed | | Unreachability timer expired | Delete entry | | +// | From | To | Cause | Update | Action | Event | +// | ========== | ========== | ========================================== | ======== | ===========| ======= | +// | Unknown | Unknown | Confirmation w/ unknown address | | | Added | +// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added | +// | Unknown | Stale | Probe | | | Added | +// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed | +// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed | +// | Incomplete | Failed | Max probes sent without reply | | Notify | Removed | +// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | | +// | Reachable | Stale | Reachable timer expired | | | Changed | +// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed | +// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Stale | Stale | Override confirmation | LinkAddr | | Changed | +// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed | +// | Stale | Delay | Packet sent | | | Changed | +// | Delay | Reachable | Upper-layer confirmation | | | Changed | +// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Delay | Stale | Probe or confirmation w/ different address | | | Changed | +// | Delay | Probe | Delay timer expired | | Send probe | Changed | +// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed | +// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Probe | Stale | Probe or confirmation w/ different address | | | Changed | +// | Probe | Probe | Retransmit timer expired | | | Changed | +// | Probe | Failed | Max probes sent without reply | | Notify | Removed | +// | Failed | Incomplete | Packet queued | | Send probe | Added | type testEntryEventType uint8 @@ -258,8 +257,8 @@ func TestEntryInitiallyUnknown(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - if got, want := e.neigh.State, Unknown; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Unknown { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown) } e.mu.Unlock() @@ -291,8 +290,8 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { Override: false, IsRouter: false, }) - if got, want := e.neigh.State, Unknown; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Unknown { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown) } e.mu.Unlock() @@ -320,8 +319,8 @@ func TestEntryUnknownToIncomplete(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() @@ -367,8 +366,8 @@ func TestEntryUnknownToStale(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -406,8 +405,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } updatedAtNanos := e.neigh.UpdatedAtNanos e.mu.Unlock() @@ -560,21 +559,15 @@ func TestEntryIncompleteToReachable(t *testing.T) { nudDisp.mu.Unlock() } -// TestEntryAddsAndClearsWakers verifies that wakers are added when -// addWakerLocked is called and cleared when address resolution finishes. In -// this case, address resolution will finish when transitioning from Incomplete -// to Reachable. -func TestEntryAddsAndClearsWakers(t *testing.T) { +func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) - w := sleep.Waker{} - s := sleep.Sleeper{} - s.AddWaker(&w, 123) - defer s.Done() - e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + } e.mu.Unlock() runImmediatelyScheduledJobs(clock) @@ -593,26 +586,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { } e.mu.Lock() - if got := e.wakers; got != nil { - t.Errorf("got e.wakers = %v, want = nil", got) - } - e.addWakerLocked(&w) - if got, want := w.IsAsserted(), false; got != want { - t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) - } - if e.wakers == nil { - t.Error("expected e.wakers to be non-nil") - } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, - IsRouter: false, + IsRouter: true, }) - if e.wakers != nil { - t.Errorf("got e.wakers = %v, want = nil", e.wakers) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } - if got, want := w.IsAsserted(), true; got != want { - t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + if !e.isRouter { + t.Errorf("got e.isRouter = %t, want = true", e.isRouter) } e.mu.Unlock() @@ -643,7 +626,7 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { +func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) @@ -663,22 +646,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { }, } linkRes.mu.Lock() - if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" { + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - linkRes.mu.Unlock() e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, + Solicited: false, Override: false, - IsRouter: true, + IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) - } - if !e.isRouter { - t.Errorf("got e.isRouter = %t, want = true", e.isRouter) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -698,7 +679,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { Entry: NeighborEntry{ Addr: entryTestAddr1, LinkAddr: entryTestLinkAddr1, - State: Reachable, + State: Stale, }, }, } @@ -709,7 +690,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryIncompleteToStale(t *testing.T) { +func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) @@ -736,11 +717,7 @@ func TestEntryIncompleteToStale(t *testing.T) { } e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) + e.handleProbeLocked(entryTestLinkAddr1) if e.neigh.State != Stale { t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } @@ -780,8 +757,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() @@ -841,8 +818,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Failed; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) } e.mu.Unlock() } @@ -885,8 +862,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { Override: false, IsRouter: true, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.isRouter, true; got != want { t.Errorf("got e.isRouter = %t, want = %t", got, want) @@ -932,8 +909,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } e.mu.Unlock() } @@ -1083,8 +1060,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() } @@ -2381,8 +2358,8 @@ func TestEntryDelayToProbe(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) } e.mu.Unlock() @@ -2447,8 +2424,8 @@ func TestEntryDelayToProbe(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.mu.Unlock() } @@ -2505,12 +2482,12 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -2620,16 +2597,16 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -2740,16 +2717,16 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -2836,16 +2813,16 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -2964,16 +2941,16 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -3101,16 +3078,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } e.mu.Unlock() @@ -3435,212 +3412,61 @@ func TestEntryProbeToFailed(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryFailedToFailed(t *testing.T) { +func TestEntryFailedToIncomplete(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 - c.MaxUnicastProbes = 3 e, nudDisp, linkRes, clock := entryTestSetup(c) - // Verify the cache contains the entry. - if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { - t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) - } - // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in // their expected state. e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) clock.Advance(waitFor) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - } - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, + wantProbes := []entryTestProbeInfo{ + // The Incomplete-to-Incomplete state transition is tested here by + // verifying that 3 reachability probes were sent. { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, }, { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, }, { - EventType: entryTestRemoved, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, }, } - nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } - nudDisp.mu.Unlock() - - failedLookups := e.nic.stats.Neighbor.FailedEntryLookups - if got := failedLookups.Value(); got != 0 { - t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 0", got) + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } e.mu.Lock() - // Verify queuing a packet to the entry immediately fails. - e.handlePacketQueuedLocked(entryTestAddr2) - state := e.neigh.State - e.mu.Unlock() - if state != Failed { - t.Errorf("got e.neigh.State = %q, want = %q", state, Failed) - } - - if got := failedLookups.Value(); got != 1 { - t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 1", got) - } -} - -func TestEntryFailedGetsDeleted(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - c.MaxUnicastProbes = 3 - e, nudDisp, linkRes, clock := entryTestSetup(c) - - // Verify the cache contains the entry. - if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { - t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) } - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - } - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime - clock.Advance(waitFor) - { - wantProbes := []entryTestProbeInfo{ - // The next three probe are sent in Probe. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } + e.mu.Unlock() wantEvents := []testEntryEventInfo{ { @@ -3653,39 +3479,21 @@ func TestEntryFailedGetsDeleted(t *testing.T) { }, }, { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, + EventType: entryTestRemoved, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, }, }, { - EventType: entryTestRemoved, + EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, }, }, } @@ -3694,9 +3502,4 @@ func TestEntryFailedGetsDeleted(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - // Verify the cache no longer contains the entry. - if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok { - t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1) - } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 5d037a27e..4a34805b5 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -20,7 +20,6 @@ import ( "reflect" "sync/atomic" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -295,15 +294,17 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // the same unresolved IP address, and transmit the saved // packet when the address has been resolved. // - // RFC 4861 section 5.2 (for IPv6): - // Once the IP address of the next-hop node is known, the sender - // examines the Neighbor Cache for link-layer information about that - // neighbor. If no entry exists, the sender creates one, sets its state - // to INCOMPLETE, initiates Address Resolution, and then queues the data - // packet pending completion of address resolution. + // RFC 4861 section 7.2.2 (for IPv6): + // While waiting for address resolution to complete, the sender MUST, for + // each neighbor, retain a small queue of packets waiting for address + // resolution to complete. The queue MUST hold at least one packet, and MAY + // contain more. However, the number of queued packets per neighbor SHOULD + // be limited to some small value. When a queue overflows, the new arrival + // SHOULD replace the oldest entry. Once address resolution completes, the + // node transmits any queued packets. if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { - r := r.Clone() + r.Acquire() n.stack.linkResQueue.enqueue(ch, r, protocol, pkt) return nil } @@ -316,7 +317,9 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // WritePacketToRemote implements NetworkInterface. func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { r := Route{ - NetProto: protocol, + routeInfo: routeInfo{ + NetProto: protocol, + }, } r.ResolveWith(remoteLinkAddr) return n.writePacket(&r, gso, protocol, pkt) @@ -545,14 +548,6 @@ func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { return n.neigh.entries(), nil } -func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) { - if n.neigh == nil { - return - } - - n.neigh.removeWaker(addr, w) -} - func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error { if n.neigh == nil { return tcpip.ErrNotSupported diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 89ea2da26..12d67409a 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -109,14 +109,6 @@ const ( // // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10. defaultMaxReachbilityConfirmations = 3 - - // defaultUnreachableTime is the default duration for how long an entry will - // remain in the FAILED state before being removed from the neighbor cache. - // - // Note, there is no equivalent protocol constant defined in RFC 4861. It - // leaves the specifics of any garbage collection mechanism up to the - // implementation. - defaultUnreachableTime = 5 * time.Second ) // NUDDispatcher is the interface integrators of netstack must implement to @@ -278,10 +270,6 @@ type NUDConfigurations struct { // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD // configuration option is necessary. MaxReachabilityConfirmations uint32 - - // UnreachableTime describes how long an entry will remain in the FAILED - // state before being removed from the neighbor cache. - UnreachableTime time.Duration } // DefaultNUDConfigurations returns a NUDConfigurations populated with default @@ -299,7 +287,6 @@ func DefaultNUDConfigurations() NUDConfigurations { MaxUnicastProbes: defaultMaxUnicastProbes, MaxAnycastDelayTime: defaultMaxAnycastDelayTime, MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations, - UnreachableTime: defaultUnreachableTime, } } @@ -329,9 +316,6 @@ func (c *NUDConfigurations) resetInvalidFields() { if c.MaxUnicastProbes == 0 { c.MaxUnicastProbes = defaultMaxUnicastProbes } - if c.UnreachableTime == 0 { - c.UnreachableTime = defaultUnreachableTime - } } // calcMaxRandomFactor calculates the maximum value of the random factor used diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 8cffb9fc6..7bca1373e 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -37,7 +37,6 @@ const ( defaultMaxUnicastProbes = 3 defaultMaxAnycastDelayTime = time.Second defaultMaxReachbilityConfirmations = 3 - defaultUnreachableTime = 5 * time.Second defaultFakeRandomNum = 0.5 ) @@ -565,58 +564,6 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { } } -func TestNUDConfigurationsUnreachableTime(t *testing.T) { - tests := []struct { - name string - unreachableTime time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - unreachableTime: 0, - want: defaultUnreachableTime, - }, - // Valid cases - { - name: "MoreThanZero", - unreachableTime: time.Millisecond, - want: time.Millisecond, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.UnreachableTime = test.unreachableTime - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - UseNeighborCache: true, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) - } - if got := sc.UnreachableTime; got != test.want { - t.Errorf("got UnreachableTime = %q, want = %q", got, test.want) - } - }) - } -} - // TestNUDStateReachableTime verifies the correctness of the ReachableTime // computation. func TestNUDStateReachableTime(t *testing.T) { diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 5d364a2b0..4a3adcf33 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -103,7 +103,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro for _, p := range packets { if cancelled { p.route.Stats().IP.OutgoingPacketErrors.Increment() - } else if _, err := p.route.Resolve(nil); err != nil { + } else if p.route.IsResolutionRequired() { p.route.Stats().IP.OutgoingPacketErrors.Increment() } else { p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index b334e27c4..7e83b7fbb 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -17,7 +17,6 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -799,19 +798,26 @@ type LinkAddressCache interface { // AddLinkAddress adds a link address to the cache. AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) - // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC). - // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver - // registered with the network protocol, the cache attempts to resolve the address - // and returns ErrWouldBlock. Waker is notified when address resolution is - // complete (success or not). + // GetLinkAddress finds the link address corresponding to the remote address + // (e.g. IP -> MAC). // - // If address resolution is required, ErrNoLinkAddress and a notification channel is - // returned for the top level caller to block. Channel is closed once address resolution - // is complete (success or not). - GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) - - // RemoveWaker removes a waker that has been added in GetLinkAddress(). - RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) + // Returns a link address for the remote address, if readily available. + // + // Returns ErrWouldBlock if the link address is not readily available, along + // with a notification channel for the caller to block on. Triggers address + // resolution asynchronously. + // + // If onResolve is provided, it will be called either immediately, if + // resolution is not required, or when address resolution is complete, with + // the resolved link address and whether resolution succeeded. After any + // callbacks have been called, the returned notification channel is closed. + // + // If specified, the local address must be an address local to the interface + // the neighbor cache belongs to. The local address is the source address of + // a packet prompting NUD/link address resolution. + // + // TODO(gvisor.dev/issue/5151): Don't return the link address. + GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) } // RawFactory produces endpoints for writing various types of raw packets. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index de5fe6ffe..b0251d0b4 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -17,7 +17,6 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -31,24 +30,7 @@ import ( // // TODO(gvisor.dev/issue/4902): Unexpose immutable fields. type Route struct { - // RemoteAddress is the final destination of the route. - RemoteAddress tcpip.Address - - // LocalAddress is the local address where the route starts. - LocalAddress tcpip.Address - - // LocalLinkAddress is the link-layer (MAC) address of the - // where the route starts. - LocalLinkAddress tcpip.LinkAddress - - // NextHop is the next node in the path to the destination. - NextHop tcpip.Address - - // NetProto is the network-layer protocol. - NetProto tcpip.NetworkProtocolNumber - - // Loop controls where WritePacket should send packets. - Loop PacketLooping + routeInfo // localAddressNIC is the interface the address is associated with. // TODO(gvisor.dev/issue/4548): Remove this field once we can query the @@ -78,6 +60,45 @@ type Route struct { linkRes LinkAddressResolver } +type routeInfo struct { + // RemoteAddress is the final destination of the route. + RemoteAddress tcpip.Address + + // LocalAddress is the local address where the route starts. + LocalAddress tcpip.Address + + // LocalLinkAddress is the link-layer (MAC) address of the + // where the route starts. + LocalLinkAddress tcpip.LinkAddress + + // NextHop is the next node in the path to the destination. + NextHop tcpip.Address + + // NetProto is the network-layer protocol. + NetProto tcpip.NetworkProtocolNumber + + // Loop controls where WritePacket should send packets. + Loop PacketLooping +} + +// RouteInfo contains all of Route's exported fields. +type RouteInfo struct { + routeInfo + + // RemoteLinkAddress is the link-layer (MAC) address of the next hop in the + // route. + RemoteLinkAddress tcpip.LinkAddress +} + +// GetFields returns a RouteInfo with all of r's exported fields. This allows +// callers to store the route's fields without retaining a reference to it. +func (r *Route) GetFields() RouteInfo { + return RouteInfo{ + routeInfo: r.routeInfo, + RemoteLinkAddress: r.RemoteLinkAddress(), + } +} + // constructAndValidateRoute validates and initializes a route. It takes // ownership of the provided local address. // @@ -152,13 +173,15 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route { r := &Route{ - NetProto: netProto, - LocalAddress: localAddr, - LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), - RemoteAddress: remoteAddr, - localAddressNIC: localAddressNIC, - outgoingNIC: outgoingNIC, - Loop: loop, + routeInfo: routeInfo{ + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), + RemoteAddress: remoteAddr, + Loop: loop, + }, + localAddressNIC: localAddressNIC, + outgoingNIC: outgoingNIC, } r.mu.Lock() @@ -264,22 +287,21 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) { r.mu.remoteLinkAddress = addr } -// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in -// case address resolution requires blocking, e.g. wait for ARP reply. Waker is -// notified when address resolution is complete (success or not). +// Resolve attempts to resolve the link address if necessary. // -// If address resolution is required, ErrNoLinkAddress and a notification channel is -// returned for the top level caller to block. Channel is closed once address resolution -// is complete (success or not). -// -// The NIC r uses must not be locked. -func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { +// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g. +// waiting for ARP reply). If address resolution is required, a notification +// channel is also returned for the caller to block on. The channel is closed +// once address resolution is complete (successful or not). If a callback is +// provided, it will be called when address resolution is complete, regardless +// of success or failure. +func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { r.mu.Lock() - defer r.mu.Unlock() if !r.isResolutionRequiredRLocked() { // Nothing to do if there is no cache (which does the resolution on cache miss) or // link address is already known. + r.mu.Unlock() return nil, nil } @@ -288,6 +310,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { // Local link address is already known. if r.RemoteAddress == r.LocalAddress { r.mu.remoteLinkAddress = r.LocalLinkAddress + r.mu.Unlock() return nil, nil } nextAddr = r.RemoteAddress @@ -300,38 +323,36 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { linkAddressResolutionRequestLocalAddr = r.LocalAddress } + // Increment the route's reference count because finishResolution retains a + // reference to the route and releases it when called. + r.acquireLocked() + r.mu.Unlock() + + finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) { + if ok { + r.ResolveWith(linkAddress) + } + if afterResolve != nil { + afterResolve() + } + r.Release() + } + if neigh := r.outgoingNIC.neigh; neigh != nil { - entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker) + _, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution) if err != nil { return ch, err } - r.mu.remoteLinkAddress = entry.LinkAddr return nil, nil } - linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker) + _, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, finishResolution) if err != nil { return ch, err } - r.mu.remoteLinkAddress = linkAddr return nil, nil } -// RemoveWaker removes a waker that has been added in Resolve(). -func (r *Route) RemoveWaker(waker *sleep.Waker) { - nextAddr := r.NextHop - if nextAddr == "" { - nextAddr = r.RemoteAddress - } - - if neigh := r.outgoingNIC.neigh; neigh != nil { - neigh.removeWaker(nextAddr, waker) - return - } - - r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker) -} - // local returns true if the route is a local route. func (r *Route) local() bool { return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback() @@ -419,46 +440,31 @@ func (r *Route) MTU() uint32 { return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU() } -// Release frees all resources associated with the route. +// Release decrements the reference counter of the resources associated with the +// route. func (r *Route) Release() { r.mu.Lock() defer r.mu.Unlock() - if r.mu.localAddressEndpoint != nil { - r.mu.localAddressEndpoint.DecRef() - r.mu.localAddressEndpoint = nil + if ep := r.mu.localAddressEndpoint; ep != nil { + ep.DecRef() } } -// Clone clones the route. -func (r *Route) Clone() *Route { +// Acquire increments the reference counter of the resources associated with the +// route. +func (r *Route) Acquire() { r.mu.RLock() defer r.mu.RUnlock() + r.acquireLocked() +} - newRoute := &Route{ - RemoteAddress: r.RemoteAddress, - LocalAddress: r.LocalAddress, - LocalLinkAddress: r.LocalLinkAddress, - NextHop: r.NextHop, - NetProto: r.NetProto, - Loop: r.Loop, - localAddressNIC: r.localAddressNIC, - outgoingNIC: r.outgoingNIC, - linkCache: r.linkCache, - linkRes: r.linkRes, - } - - newRoute.mu.Lock() - defer newRoute.mu.Unlock() - newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint - if newRoute.mu.localAddressEndpoint != nil { - if !newRoute.mu.localAddressEndpoint.IncRef() { - panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress)) +func (r *Route) acquireLocked() { + if ep := r.mu.localAddressEndpoint; ep != nil { + if !ep.IncRef() { + panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress)) } } - newRoute.mu.remoteLinkAddress = r.mu.remoteLinkAddress - - return newRoute } // Stack returns the instance of the Stack that owns this route. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 026d330c4..114643b03 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -29,7 +29,6 @@ import ( "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -1520,7 +1519,7 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t } // GetLinkAddress implements LinkAddressCache.GetLinkAddress. -func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() nic := s.nics[nicID] if nic == nil { @@ -1531,7 +1530,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, onResolve) } // Neighbors returns all IP to MAC address associations. @@ -1547,29 +1546,6 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) { return nic.neighbors() } -// RemoveWaker removes a waker that has been added when link resolution for -// addr was requested. -func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { - if s.useNeighborCache { - s.mu.RLock() - nic, ok := s.nics[nicID] - s.mu.RUnlock() - - if ok { - nic.removeWaker(addr, waker) - } - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - if nic := s.nics[nicID]; nic == nil { - fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} - s.linkAddrCache.removeWaker(fullAddr, waker) - } -} - // AddStaticNeighbor statically associates an IP address to a MAC address. func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { s.mu.RLock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 68dcd9e61..856ebf6d4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1602,7 +1602,10 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, &stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + var wantRoute stack.Route + wantRoute.LocalAddress = header.IPv4Any + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1656,7 +1659,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + var wantRoute stack.Route + wantRoute.LocalAddress = nic1Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1666,7 +1672,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, &stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + wantRoute = stack.Route{} + wantRoute.LocalAddress = nic2Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } @@ -1682,7 +1691,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + wantRoute = stack.Route{} + wantRoute.LocalAddress = nic1Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } @@ -4274,8 +4286,8 @@ func TestWritePacketToRemote(t *testing.T) { if got, want := pkt.Proto, test.protocol; got != want { t.Fatalf("pkt.Proto = %d, want %d", got, want) } - if got, want := pkt.Route.RemoteLinkAddress(), linkAddr2; got != want { - t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want) + if pkt.Route.RemoteLinkAddress != linkAddr2 { + t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2) } if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 66eb562ba..dd552b8b9 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -77,6 +77,7 @@ func (f *fakeTransportEndpoint) Abort() { } func (f *fakeTransportEndpoint) Close() { + // TODO(gvisor.dev/issue/5153): Consider retaining the route. f.route.Release() } @@ -146,16 +147,16 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return tcpip.ErrNoRoute } - defer r.Release() // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { + r.Release() return err } - f.route = r.Clone() + f.route = r return nil } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 74fe19e98..d1e4a7cb7 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -504,7 +504,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer r.Release() id := stack.TransportEndpointID{ LocalAddress: r.LocalAddress, @@ -519,11 +518,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { + r.Release() return err } e.ID = id - e.route = r.Clone() + e.route = r e.RegisterNICID = nicID e.state = stateConnected diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 5df703fb2..7befcfc9b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -261,15 +261,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } e.mu.RLock() + defer e.mu.RUnlock() if e.closed { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } payloadBytes, err := p.FullPayload() if err != nil { - e.mu.RUnlock() return 0, nil, err } @@ -278,7 +277,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } dstAddr := ip.DestinationAddress() @@ -300,39 +298,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { - e.mu.RUnlock() return 0, nil, tcpip.ErrDestinationRequired } - if e.route.IsResolutionRequired() { - savedRoute := e.route - // Promote lock to exclusive if using a shared route, - // given that it may need to change in finishWrite. - e.mu.RUnlock() - e.mu.Lock() - - // Make sure that the route didn't change during the - // time we didn't hold the lock. - if !e.connected || savedRoute != e.route { - e.mu.Unlock() - return 0, nil, tcpip.ErrInvalidEndpointState - } - - n, ch, err := e.finishWrite(payloadBytes, savedRoute) - e.mu.Unlock() - return n, ch, err - } - - n, ch, err := e.finishWrite(payloadBytes, e.route) - e.mu.RUnlock() - return n, ch, err + return e.finishWrite(payloadBytes, e.route) } // The caller provided a destination. Reject destination address if it // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { - e.mu.RUnlock() return 0, nil, tcpip.ErrNoRoute } @@ -340,13 +315,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { - e.mu.RUnlock() return 0, nil, err } n, ch, err := e.finishWrite(payloadBytes, route) route.Release() - e.mu.RUnlock() return n, ch, err } @@ -435,11 +408,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer route.Release() if e.associated { // Re-register the endpoint with the appropriate NIC. if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + route.Release() return err } e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) @@ -447,7 +420,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Save the route we've connected via. - e.route = route.Clone() + e.route = route e.connected = true return nil diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 3e1041cbe..2d96a65bd 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -778,7 +778,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() - s := sleep.Sleeper{} + var s sleep.Sleeper s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) for { diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index c944dccc0..0dc710276 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -462,7 +462,7 @@ func (h *handshake) processSegments() *tcpip.Error { func (h *handshake) resolveRoute() *tcpip.Error { // Set up the wakers. - s := sleep.Sleeper{} + var s sleep.Sleeper resolutionWaker := &sleep.Waker{} s.AddWaker(resolutionWaker, wakerForResolution) s.AddWaker(&h.ep.notificationWaker, wakerForNotification) @@ -470,24 +470,27 @@ func (h *handshake) resolveRoute() *tcpip.Error { // Initial action is to resolve route. index := wakerForResolution + attemptedResolution := false for { switch index { case wakerForResolution: - if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { - if err == tcpip.ErrNoLinkAddress { - h.ep.stats.SendErrors.NoLinkAddr.Increment() - } else if err != nil { + if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock { + if err != nil { h.ep.stats.SendErrors.NoRoute.Increment() } // Either success (err == nil) or failure. return err } + if attemptedResolution { + h.ep.stats.SendErrors.NoLinkAddr.Increment() + return tcpip.ErrNoLinkAddress + } + attemptedResolution = true // Resolution not completed. Keep trying... case wakerForNotification: n := h.ep.fetchNotifications() if n¬ifyClose != 0 { - h.ep.route.RemoveWaker(resolutionWaker) return tcpip.ErrAborted } if n¬ifyDrain != 0 { @@ -563,7 +566,7 @@ func (h *handshake) start() *tcpip.Error { // complete completes the TCP 3-way handshake initiated by h.start(). func (h *handshake) complete() *tcpip.Error { // Set up the wakers. - s := sleep.Sleeper{} + var s sleep.Sleeper resendWaker := sleep.Waker{} s.AddWaker(&resendWaker, wakerForResend) s.AddWaker(&h.ep.notificationWaker, wakerForNotification) @@ -1512,7 +1515,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } // Initialize the sleeper based on the wakers in funcs. - s := sleep.Sleeper{} + var s sleep.Sleeper for i := range funcs { s.AddWaker(funcs[i].w, i) } @@ -1699,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { const notification = 2 const timeWaitDone = 3 - s := sleep.Sleeper{} + var s sleep.Sleeper defer s.Done() s.AddWaker(&e.newSegmentWaker, newSegment) s.AddWaker(&e.notificationWaker, notification) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 2128206d7..c88e74bec 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2292,7 +2292,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc e.isRegistered = true e.setEndpointState(StateConnecting) - e.route = r.Clone() + r.Acquire() + e.route = r e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index d919fa011..24d0c2cb9 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1017,7 +1017,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer r.Release() id := stack.TransportEndpointID{ LocalAddress: e.ID.LocalAddress, @@ -1045,6 +1044,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { + r.Release() return err } @@ -1055,7 +1055,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.ID = id e.boundBindToDevice = btd - e.route = r.Clone() + e.route = r e.dstPort = addr.Port e.RegisterNICID = nicID e.effectiveNetProtos = netProtos |