From 2d90bc54809766927e6028fac5b9f67cd2a13c3e Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Sat, 30 Jan 2021 11:35:35 -0800 Subject: Implement LinkAddressResolver on NetworkEndpoints This removes the need to provide the link address request with the NIC the request is being performed on since the NetworkEndpoints already have a reference to the NIC. PiperOrigin-RevId: 354721940 --- pkg/tcpip/network/arp/BUILD | 1 - pkg/tcpip/network/arp/arp.go | 59 ++++++----------------- pkg/tcpip/network/arp/arp_test.go | 86 +++++++--------------------------- pkg/tcpip/network/arp/stats_test.go | 42 ----------------- pkg/tcpip/network/ipv6/icmp.go | 35 +++++--------- pkg/tcpip/network/ipv6/icmp_test.go | 27 ++++++----- pkg/tcpip/network/ipv6/ipv6.go | 1 + pkg/tcpip/stack/forwarding_test.go | 18 +++---- pkg/tcpip/stack/linkaddrcache.go | 8 ++-- pkg/tcpip/stack/linkaddrcache_test.go | 44 ++++++++--------- pkg/tcpip/stack/neighbor_cache_test.go | 2 +- pkg/tcpip/stack/neighbor_entry.go | 4 +- pkg/tcpip/stack/neighbor_entry_test.go | 2 +- pkg/tcpip/stack/nic.go | 43 +++++++++++++---- pkg/tcpip/stack/nic_test.go | 20 -------- pkg/tcpip/stack/nud_test.go | 49 +++++++------------ pkg/tcpip/stack/registration.go | 8 +--- pkg/tcpip/stack/route.go | 2 +- pkg/tcpip/stack/stack.go | 18 +------ 19 files changed, 152 insertions(+), 317 deletions(-) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index c7ab876bf..933845269 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -10,7 +10,6 @@ go_library( ], visibility = ["//visibility:public"], deps = [ - "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7838cc753..5c79d6485 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -22,7 +22,6 @@ import ( "reflect" "sync/atomic" - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -35,6 +34,8 @@ const ( ProtocolNumber = header.ARPProtocolNumber ) +var _ stack.LinkAddressResolver = (*endpoint)(nil) + // ARP endpoints need to implement stack.NetworkEndpoint because the stack // considers the layer above the link-layer a network layer; the only // facility provided by the stack to deliver packets to a layer above @@ -101,9 +102,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.ARPSize } -func (e *endpoint) Close() { - e.protocol.forgetEndpoint(e.nic.ID()) -} +func (*endpoint) Close() {} func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} @@ -154,7 +153,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if e.nud == nil { e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr) } else { - e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) + e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e) } respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -221,19 +220,10 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { } var _ stack.NetworkProtocol = (*protocol)(nil) -var _ stack.LinkAddressResolver = (*protocol)(nil) // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { stack *stack.Stack - - mu struct { - sync.RWMutex - - // eps is keyed by NICID to allow protocol methods to retrieve the correct - // endpoint depending on the NIC. - eps map[tcpip.NICID]*endpoint - } } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -257,43 +247,26 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L stackStats := p.stack.Stats() e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP) - p.mu.Lock() - p.mu.eps[nic.ID()] = e - p.mu.Unlock() - return e } -func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { - p.mu.Lock() - defer p.mu.Unlock() - delete(p.mu.eps, nicID) -} - // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. -func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error { - nicID := nic.ID() - - p.mu.Lock() - netEP, ok := p.mu.eps[nicID] - p.mu.Unlock() - if !ok { - return &tcpip.ErrNotConnected{} - } +func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { + nicID := e.nic.ID() - stats := netEP.stats.arp + stats := e.stats.arp if len(remoteLinkAddr) == 0 { remoteLinkAddr = header.EthernetBroadcastAddress } if len(localAddr) == 0 { - addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) if !ok { return &tcpip.ErrUnknownNICID{} } @@ -304,13 +277,13 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } localAddr = addr.Address - } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { stats.outgoingRequestBadLocalAddressErrors.Increment() return &tcpip.ErrBadLocalAddress{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize, + ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) pkt.NetworkProtocolNumber = ProtocolNumber @@ -318,14 +291,14 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot h.SetOp(header.ARPRequest) // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a // link address. - _ = copy(h.HardwareAddressSender(), nic.LinkAddress()) + _ = copy(h.HardwareAddressSender(), e.nic.LinkAddress()) if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize { panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } - if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { stats.outgoingRequestsDropped.Increment() return err } @@ -334,7 +307,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. -func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { return header.EthernetBroadcastAddress, true } @@ -369,9 +342,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return &protocol{ stack: s, - mu: struct { - sync.RWMutex - eps map[tcpip.NICID]*endpoint - }{eps: make(map[tcpip.NICID]*endpoint)}, } } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index b0f07aa44..d753a97af 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -530,52 +530,19 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { } } -var _ stack.NetworkInterface = (*testInterface)(nil) +var _ stack.LinkEndpoint = (*testLinkEndpoint)(nil) -type testInterface struct { +type testLinkEndpoint struct { stack.LinkEndpoint - nicID tcpip.NICID - writeErr tcpip.Error } -func (t *testInterface) ID() tcpip.NICID { - return t.nicID -} - -func (*testInterface) IsLoopback() bool { - return false -} - -func (*testInterface) Name() string { - return "" -} - -func (*testInterface) Enabled() bool { - return true -} - -func (*testInterface) Promiscuous() bool { - return false -} - -func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) -} - -func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) -} - -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if t.writeErr != nil { return t.writeErr } - var r stack.RouteInfo - r.NetProto = protocol - r.RemoteLinkAddress = remoteLinkAddr return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) } @@ -709,32 +676,31 @@ func TestLinkAddressRequest(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, }) - p := s.NetworkProtocolInstance(arp.ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") - } - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - if err := s.CreateNIC(nicID, linkEP); err != nil { + if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } + ep, err := s.GetNetworkEndpoint(nicID, arp.ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, arp.ProtocolNumber, err) + } + linkRes, ok := ep.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep) + } + if len(test.nicAddr) != 0 { if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) } } - // We pass a test network interface to LinkAddressRequest with the same - // NIC ID and link endpoint used by the NIC we created earlier so that we - // can mock a link address request and observe the packets sent to the - // link endpoint even though the stack uses the real NIC to validate the - // local address. - iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr} - err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff) + { + err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff) + } } if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent { @@ -782,19 +748,3 @@ func TestLinkAddressRequest(t *testing.T) { }) } } - -func TestLinkAddressRequestWithoutNIC(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - }) - p := s.NetworkProtocolInstance(arp.ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") - } - - err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}) - if _, ok := err.(*tcpip.ErrNotConnected); !ok { - t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, &tcpip.ErrNotConnected{}) - } -} diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go index 036fdf739..d3b56c635 100644 --- a/pkg/tcpip/network/arp/stats_test.go +++ b/pkg/tcpip/network/arp/stats_test.go @@ -34,48 +34,6 @@ func (t *testInterface) ID() tcpip.NICID { return t.nicID } -func knownNICIDs(proto *protocol) []tcpip.NICID { - var nicIDs []tcpip.NICID - - for k := range proto.mu.eps { - nicIDs = append(nicIDs, k) - } - - return nicIDs -} - -func TestClearEndpointFromProtocolOnClose(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - nic := testInterface{nicID: 1} - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) - var nicIDs []tcpip.NICID - - proto.mu.Lock() - foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - - if !hasEndpointBeforeClose { - t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) - } - if foundEP != ep { - t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) - } - - ep.Close() - - proto.mu.Lock() - _, hasEndpointAfterClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if hasEndpointAfterClose { - t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } -} - func TestMultiCounterStatsInitialization(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 7298bd061..8db2454d3 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -290,7 +290,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { received.invalid.Increment() return } else if e.nud != nil { - e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e) } else { e.linkAddrCache.AddLinkAddress(srcAddr, sourceLinkAddr) } @@ -578,7 +578,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { if e.nud != nil { // A RS with a specified source IP address modifies the NUD state // machine in the same way a reachability probe would. - e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e) } } @@ -628,7 +628,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // If the RA has the source link layer option, update the link address // cache with the link address for the advertised router. if len(sourceLinkAddr) != 0 && e.nud != nil { - e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e) } e.mu.Lock() @@ -694,24 +694,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } } -var _ stack.LinkAddressResolver = (*protocol)(nil) - // LinkAddressProtocol implements stack.LinkAddressResolver. -func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv6ProtocolNumber } // LinkAddressRequest implements stack.LinkAddressResolver. -func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error { - nicID := nic.ID() - - p.mu.Lock() - netEP, ok := p.mu.eps[nicID] - p.mu.Unlock() - if !ok { - return &tcpip.ErrNotConnected{} - } - +func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { remoteAddr := targetAddr if len(remoteLinkAddr) == 0 { remoteAddr = header.SolicitedNodeAddr(targetAddr) @@ -719,22 +708,22 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } if len(localAddr) == 0 { - addressEndpoint := netEP.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */) + addressEndpoint := e.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */) if addressEndpoint == nil { return &tcpip.ErrNetworkUnreachable{} } localAddr = addressEndpoint.AddressWithPrefix().Address - } else if p.stack.CheckLocalAddress(nicID, ProtocolNumber, localAddr) == 0 { + } else if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, localAddr) == 0 { return &tcpip.ErrBadLocalAddress{} } optsSerializer := header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()), + header.NDPSourceLinkLayerAddressOption(e.nic.LinkAddress()), } neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize, + ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize, }) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) @@ -751,9 +740,9 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot panic(fmt.Sprintf("failed to add IP header: %s", err)) } - stat := netEP.stats.icmp.packetsSent + stat := e.stats.icmp.packetsSent - if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { stat.dropped.Increment() return err } @@ -763,7 +752,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } // ResolveStaticAddress implements stack.LinkAddressResolver. -func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if header.IsV6MulticastAddress(addr) { return header.EthernetAddressFromMulticastIPv6Address(addr), true } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index db1c2e663..a5c88444e 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -1346,29 +1346,32 @@ func TestLinkAddressRequest(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) - p := s.NetworkProtocolInstance(ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") - } linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) if err := s.CreateNIC(nicID, linkEP); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } + + ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err) + } + linkRes, ok := ep.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep) + } + if len(test.nicAddr) != 0 { if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) } } - // We pass a test network interface to LinkAddressRequest with the same NIC - // ID and link endpoint used by the NIC we created earlier so that we can - // mock a link address request and observe the packets sent to the link - // endpoint even though the stack uses the real NIC. - err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff) + { + err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff) + } } if test.expectedErr != nil { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index caa62b3a2..b55a35525 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -164,6 +164,7 @@ func getLabel(addr tcpip.Address) uint8 { panic(fmt.Sprintf("should have a label for address = %s", addr)) } +var _ stack.LinkAddressResolver = (*endpoint)(nil) var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 63a42a2ea..1e4ddf0d5 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -41,6 +41,7 @@ const ( protocolNumberOffset = 2 ) +var _ LinkAddressResolver = (*fwdTestNetworkEndpoint)(nil) var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) // fwdTestNetworkEndpoint is a network-layer protocol endpoint. @@ -153,7 +154,6 @@ type fwdTestNetworkEndpointStats struct{} // IsNetworkEndpointStats implements stack.NetworkEndpointStats. func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {} -var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) // fwdTestNetworkProtocol is a network-layer protocol that implements Address @@ -219,23 +219,23 @@ func (*fwdTestNetworkProtocol) Close() {} func (*fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { - if f.onLinkAddressResolved != nil { - time.AfterFunc(f.addrResolveDelay, func() { - f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) +func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { + if fn := f.proto.onLinkAddressResolved; fn != nil { + time.AfterFunc(f.proto.addrResolveDelay, func() { + fn(f.proto.addrCache, f.proto.neigh, addr, remoteLinkAddr) }) } return nil } -func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if f.onResolveStaticAddress != nil { - return f.onResolveStaticAddress(addr) +func (f *fwdTestNetworkEndpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if fn := f.proto.onResolveStaticAddress; fn != nil { + return fn(addr) } return "", false } -func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 930b8f795..cd2bb3417 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -199,7 +199,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { c.mu.Lock() entry := c.getOrCreateEntryLocked(k) entry.mu.Lock() @@ -224,7 +224,7 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA } if entry.mu.done == nil { entry.mu.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, nic, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + go c.startAddressResolution(k, linkRes, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{} default: @@ -232,11 +232,11 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA } } -func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */, nic) + linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */) select { case now := <-time.After(c.resolutionTimeout): diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 466a5e8d9..40017c8b6 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -48,7 +48,7 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { +func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { // TODO(gvisor.dev/issue/5141): Use a fake clock. time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { @@ -80,7 +80,7 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, tcpip.Error) { var attemptedResolution bool for { - got, ch, err := c.get(addr, linkRes, "", nil, nil) + got, ch, err := c.get(addr, linkRes, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); ok { if attemptedResolution { return got, &tcpip.ErrTimeout{} @@ -104,23 +104,23 @@ func TestCacheOverflow(t *testing.T) { for i := len(testAddrs) - 1; i >= 0; i-- { e := testAddrs[i] c.AddLinkAddress(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, nil, "", nil) if err != nil { - t.Errorf("insert %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) + t.Errorf("insert %d, c.get(%s, nil, '', nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("insert %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) + t.Errorf("insert %d, got c.get(%s, nil, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { e := testAddrs[i] - got, _, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, nil, "", nil) if err != nil { - t.Errorf("check %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) + t.Errorf("check %d, c.get(%s, nil, '', nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("check %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) + t.Errorf("check %d, got c.get(%s, nil, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // The earliest entries should no longer be in the cache. @@ -154,12 +154,12 @@ func TestCacheConcurrent(t *testing.T) { // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil, nil) + got, _, err := c.get(e.addr, linkRes, "", nil) if err != nil { - t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, _, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, _, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } e = testAddrs[0] @@ -177,9 +177,9 @@ func TestCacheAgeLimit(t *testing.T) { e := testAddrs[0] c.AddLinkAddress(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - _, _, err := c.get(e.addr, linkRes, "", nil, nil) + _, _, err := c.get(e.addr, linkRes, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = ErrWouldBlock", e.addr, err) + t.Errorf("got c.get(%s, _, '', nil) = %s, want = ErrWouldBlock", e.addr, err) } } @@ -188,21 +188,21 @@ func TestCacheReplace(t *testing.T) { e := testAddrs[0] l2 := e.linkAddr + "2" c.AddLinkAddress(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, nil, "", nil) if err != nil { - t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, nil, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, nil, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } c.AddLinkAddress(e.addr, l2) - got, _, err = c.get(e.addr, nil, "", nil, nil) + got, _, err = c.get(e.addr, nil, "", nil) if err != nil { - t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, nil, '', nil): %s", e.addr, err) } if got != l2 { - t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, l2) + t.Errorf("got c.get(%s, nil, '', nil) = %s, want = %s", e.addr, got, l2) } } @@ -228,12 +228,12 @@ func TestCacheResolution(t *testing.T) { // Check that after resolved, address stays in the cache and never returns WouldBlock. for i := 0; i < 10; i++ { e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil, nil) + got, _, err := c.get(e.addr, linkRes, "", nil) if err != nil { - t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, _, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, _, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } } } diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 2870e4f66..0f7925774 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -194,7 +194,7 @@ type testNeighborResolver struct { var _ LinkAddressResolver = (*testNeighborResolver)(nil) -func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { +func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { if !r.dropReplies { // Delay handling the request to emulate network latency. r.clock.AfterFunc(r.delay, func() { diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 53ac9bb6e..a037ca6f9 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -254,7 +254,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return @@ -340,7 +340,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // address SHOULD be placed in the IP Source Address of the outgoing // solicitation. // - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil { // There is no need to log the error here; the NUD implementation may // assume a working link. A valid link should be the responsibility of // the NIC/stack.LinkEndpoint. diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 140b8ca00..c5c3d266b 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -193,7 +193,7 @@ func (p entryTestProbeInfo) String() string { // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts // to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { +func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { p := entryTestProbeInfo{ RemoteAddress: targetAddr, RemoteLinkAddress: linkAddr, diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index e56a624fe..0707c3ce2 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -42,7 +42,8 @@ type NIC struct { // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. - networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. // @@ -133,12 +134,13 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic := &NIC{ LinkEndpoint: ep, - stack: stack, - id: id, - name: name, - context: ctx, - stats: makeNICStats(), - networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), + stack: stack, + id: id, + name: name, + context: ctx, + stats: makeNICStats(), + networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), } nic.linkResQueue.init(nic) nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts) @@ -146,7 +148,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // Check for Neighbor Unreachability Detection support. var nud NUDHandler - if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 && stack.useNeighborCache { + if ep.Capabilities()&CapabilityResolutionRequired != 0 && stack.useNeighborCache { rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds())) nic.neigh = &neighborCache{ nic: nic, @@ -170,7 +172,13 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC for _, netProto := range stack.networkProtocols { netNum := netProto.Number() nic.mu.packetEPs[netNum] = new(packetEndpointList) - nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic) + + netEP := netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic) + nic.networkEndpoints[netNum] = netEP + + if r, ok := netEP.(LinkAddressResolver); ok { + nic.linkAddrResolvers[r.LinkAddressProtocol()] = r + } } nic.LinkEndpoint.Attach(nic) @@ -593,13 +601,28 @@ func (n *NIC) confirmReachable(addr tcpip.Address) { } } +func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { + linkRes, ok := n.linkAddrResolvers[protocol] + if !ok { + return &tcpip.ErrNotSupported{} + } + + if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok { + onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) + return nil + } + + _, _, err := n.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) + return err +} + func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { if n.neigh != nil { entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve) return entry.LinkAddr, ch, err } - return n.linkAddrCache.get(addr, linkRes, localAddr, n, onResolve) + return n.linkAddrCache.get(addr, linkRes, localAddr, onResolve) } func (n *NIC) neighbors() ([]NeighborEntry, tcpip.Error) { diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 2f719fbe5..3564202d8 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -111,8 +111,6 @@ type testIPv6EndpointStats struct{} // IsNetworkEndpointStats implements stack.NetworkEndpointStats. func (*testIPv6EndpointStats) IsNetworkEndpointStats() {} -var _ LinkAddressResolver = (*testIPv6Protocol)(nil) - // We use this instead of ipv6.protocol because the ipv6 package depends on // the stack package which this test lives in, causing a cyclic dependency. type testIPv6Protocol struct{} @@ -169,24 +167,6 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo return 0, false, false } -// LinkAddressProtocol implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return header.IPv6ProtocolNumber -} - -// LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { - return nil -} - -// ResolveStaticAddress implements LinkAddressResolver. -func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if header.IsV6MulticastAddress(addr) { - return header.EthernetAddressFromMulticastIPv6Address(addr), true - } - return "", false -} - func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index ebfd5eb45..504acc246 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -72,17 +72,17 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { } // TestNUDConfigurationFailsForNotSupported tests to make sure we get a -// NotSupported error if we attempt to retrieve NUD configurations when the -// stack doesn't support NUD. +// NotSupported error if we attempt to retrieve or set NUD configurations when +// the stack doesn't support NUD. // // The stack will report to not support NUD if a neighbor cache for a given NIC // is not allocated. The networking stack will only allocate neighbor caches if -// a protocol providing link address resolution is specified (e.g. ARP, IPv6). +// the NIC requires link resolution. func TestNUDConfigurationFailsForNotSupported(t *testing.T) { const nicID = 1 e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ NUDConfigs: stack.DefaultNUDConfigurations(), @@ -91,38 +91,21 @@ func TestNUDConfigurationFailsForNotSupported(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - _, err := s.NUDConfigurations(nicID) - if _, ok := err.(*tcpip.ErrNotSupported); !ok { - t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{}) - } -} -// TestNUDConfigurationFailsForNotSupported tests to make sure we get a -// NotSupported error if we attempt to set NUD configurations when the stack -// doesn't support NUD. -// -// The stack will report to not support NUD if a neighbor cache for a given NIC -// is not allocated. The networking stack will only allocate neighbor caches if -// a protocol providing link address resolution is specified (e.g. ARP, IPv6). -func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { - const nicID = 1 - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, + t.Run("Get", func(t *testing.T) { + _, err := s.NUDConfigurations(nicID) + if _, ok := err.(*tcpip.ErrNotSupported); !ok { + t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{}) + } }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - config := stack.NUDConfigurations{} - err := s.SetNUDConfigurations(nicID, config) - if _, ok := err.(*tcpip.ErrNotSupported); !ok { - t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{}) - } + t.Run("Set", func(t *testing.T) { + config := stack.NUDConfigurations{} + err := s.SetNUDConfigurations(nicID, config) + if _, ok := err.(*tcpip.ErrNotSupported); !ok { + t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{}) + } + }) } // TestDefaultNUDConfigurationIsValid verifies that calling diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 510da8689..64b5627e1 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -824,16 +824,12 @@ type InjectableLinkEndpoint interface { InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error } -// A LinkAddressResolver is an extension to a NetworkProtocol that -// can resolve link addresses. +// A LinkAddressResolver handles link address resolution for a network protocol. type LinkAddressResolver interface { // LinkAddressRequest sends a request for the link address of the target // address. The request is broadcasted on the local network if a remote link // address is not provided. - // - // The request is sent from the passed network interface. If the interface - // local address is unspecified, any interface local address may be used. - LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) tcpip.Error + LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 4ae0f2a1a..1c8ef6ed4 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -174,7 +174,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA } if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { + if linkRes, ok := r.outgoingNIC.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 119c4c505..e720d676f 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -376,7 +376,6 @@ func (u *uniqueIDGenerator) UniqueID() uint64 { type Stack struct { transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol - linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver // rawFactory creates raw endpoints. If nil, raw endpoints are // disabled. It is set during Stack creation and is immutable. @@ -635,7 +634,6 @@ func New(opts Options) *Stack { s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), cleanupEndpoints: make(map[TransportEndpoint]struct{}), PortManager: ports.NewPortManager(), @@ -666,9 +664,6 @@ func New(opts Options) *Stack { for _, netProtoFactory := range opts.NetworkProtocols { netProto := netProtoFactory(s) s.networkProtocols[netProto.Number()] = netProto - if r, ok := netProto.(LinkAddressResolver); ok { - s.linkAddrResolvers[r.LinkAddressProtocol()] = r - } } // Add specified transport protocols. @@ -1561,18 +1556,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, return &tcpip.ErrUnknownNICID{} } - linkRes, ok := s.linkAddrResolvers[protocol] - if !ok { - return &tcpip.ErrNotSupported{} - } - - if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok { - onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) - return nil - } - - _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) - return err + return nic.getLinkAddress(addr, localAddr, protocol, onResolve) } // Neighbors returns all IP to MAC address associations. -- cgit v1.2.3