summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/network/arp/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go59
-rw-r--r--pkg/tcpip/network/arp/arp_test.go86
-rw-r--r--pkg/tcpip/network/arp/stats_test.go42
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go35
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go27
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go1
-rw-r--r--pkg/tcpip/stack/forwarding_test.go18
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go8
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go44
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go2
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go4
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go2
-rw-r--r--pkg/tcpip/stack/nic.go43
-rw-r--r--pkg/tcpip/stack/nic_test.go20
-rw-r--r--pkg/tcpip/stack/nud_test.go49
-rw-r--r--pkg/tcpip/stack/registration.go8
-rw-r--r--pkg/tcpip/stack/route.go2
-rw-r--r--pkg/tcpip/stack/stack.go18
19 files changed, 152 insertions, 317 deletions
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index c7ab876bf..933845269 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -10,7 +10,6 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
- "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 7838cc753..5c79d6485 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -22,7 +22,6 @@ import (
"reflect"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -35,6 +34,8 @@ const (
ProtocolNumber = header.ARPProtocolNumber
)
+var _ stack.LinkAddressResolver = (*endpoint)(nil)
+
// ARP endpoints need to implement stack.NetworkEndpoint because the stack
// considers the layer above the link-layer a network layer; the only
// facility provided by the stack to deliver packets to a layer above
@@ -101,9 +102,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) Close() {
- e.protocol.forgetEndpoint(e.nic.ID())
-}
+func (*endpoint) Close() {}
func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error {
return &tcpip.ErrNotSupported{}
@@ -154,7 +153,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if e.nud == nil {
e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr)
} else {
- e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
+ e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e)
}
respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -221,19 +220,10 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
}
var _ stack.NetworkProtocol = (*protocol)(nil)
-var _ stack.LinkAddressResolver = (*protocol)(nil)
// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
type protocol struct {
stack *stack.Stack
-
- mu struct {
- sync.RWMutex
-
- // eps is keyed by NICID to allow protocol methods to retrieve the correct
- // endpoint depending on the NIC.
- eps map[tcpip.NICID]*endpoint
- }
}
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
@@ -257,43 +247,26 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
stackStats := p.stack.Stats()
e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP)
- p.mu.Lock()
- p.mu.eps[nic.ID()] = e
- p.mu.Unlock()
-
return e
}
-func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
- p.mu.Lock()
- defer p.mu.Unlock()
- delete(p.mu.eps, nicID)
-}
-
// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
-func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error {
- nicID := nic.ID()
-
- p.mu.Lock()
- netEP, ok := p.mu.eps[nicID]
- p.mu.Unlock()
- if !ok {
- return &tcpip.ErrNotConnected{}
- }
+func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
+ nicID := e.nic.ID()
- stats := netEP.stats.arp
+ stats := e.stats.arp
if len(remoteLinkAddr) == 0 {
remoteLinkAddr = header.EthernetBroadcastAddress
}
if len(localAddr) == 0 {
- addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
+ addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
if !ok {
return &tcpip.ErrUnknownNICID{}
}
@@ -304,13 +277,13 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
localAddr = addr.Address
- } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
stats.outgoingRequestBadLocalAddressErrors.Increment()
return &tcpip.ErrBadLocalAddress{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize,
+ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
pkt.NetworkProtocolNumber = ProtocolNumber
@@ -318,14 +291,14 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
h.SetOp(header.ARPRequest)
// TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
// link address.
- _ = copy(h.HardwareAddressSender(), nic.LinkAddress())
+ _ = copy(h.HardwareAddressSender(), e.nic.LinkAddress())
if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize {
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize {
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
- if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
stats.outgoingRequestsDropped.Increment()
return err
}
@@ -334,7 +307,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
-func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
return header.EthernetBroadcastAddress, true
}
@@ -369,9 +342,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return &protocol{
stack: s,
- mu: struct {
- sync.RWMutex
- eps map[tcpip.NICID]*endpoint
- }{eps: make(map[tcpip.NICID]*endpoint)},
}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index b0f07aa44..d753a97af 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -530,52 +530,19 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
}
}
-var _ stack.NetworkInterface = (*testInterface)(nil)
+var _ stack.LinkEndpoint = (*testLinkEndpoint)(nil)
-type testInterface struct {
+type testLinkEndpoint struct {
stack.LinkEndpoint
- nicID tcpip.NICID
-
writeErr tcpip.Error
}
-func (t *testInterface) ID() tcpip.NICID {
- return t.nicID
-}
-
-func (*testInterface) IsLoopback() bool {
- return false
-}
-
-func (*testInterface) Name() string {
- return ""
-}
-
-func (*testInterface) Enabled() bool {
- return true
-}
-
-func (*testInterface) Promiscuous() bool {
- return false
-}
-
-func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt)
-}
-
-func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol)
-}
-
-func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if t.writeErr != nil {
return t.writeErr
}
- var r stack.RouteInfo
- r.NetProto = protocol
- r.RemoteLinkAddress = remoteLinkAddr
return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt)
}
@@ -709,32 +676,31 @@ func TestLinkAddressRequest(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
})
- p := s.NetworkProtocolInstance(arp.ProtocolNumber)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
- }
-
linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
- if err := s.CreateNIC(nicID, linkEP); err != nil {
+ if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
+ ep, err := s.GetNetworkEndpoint(nicID, arp.ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, arp.ProtocolNumber, err)
+ }
+ linkRes, ok := ep.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep)
+ }
+
if len(test.nicAddr) != 0 {
if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
}
}
- // We pass a test network interface to LinkAddressRequest with the same
- // NIC ID and link endpoint used by the NIC we created earlier so that we
- // can mock a link address request and observe the packets sent to the
- // link endpoint even though the stack uses the real NIC to validate the
- // local address.
- iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr}
- err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface)
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff)
+ {
+ err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr)
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff)
+ }
}
if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent {
@@ -782,19 +748,3 @@ func TestLinkAddressRequest(t *testing.T) {
})
}
}
-
-func TestLinkAddressRequestWithoutNIC(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
- })
- p := s.NetworkProtocolInstance(arp.ProtocolNumber)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
- }
-
- err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID})
- if _, ok := err.(*tcpip.ErrNotConnected); !ok {
- t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, &tcpip.ErrNotConnected{})
- }
-}
diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go
index 036fdf739..d3b56c635 100644
--- a/pkg/tcpip/network/arp/stats_test.go
+++ b/pkg/tcpip/network/arp/stats_test.go
@@ -34,48 +34,6 @@ func (t *testInterface) ID() tcpip.NICID {
return t.nicID
}
-func knownNICIDs(proto *protocol) []tcpip.NICID {
- var nicIDs []tcpip.NICID
-
- for k := range proto.mu.eps {
- nicIDs = append(nicIDs, k)
- }
-
- return nicIDs
-}
-
-func TestClearEndpointFromProtocolOnClose(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- nic := testInterface{nicID: 1}
- ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
- var nicIDs []tcpip.NICID
-
- proto.mu.Lock()
- foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
-
- if !hasEndpointBeforeClose {
- t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs)
- }
- if foundEP != ep {
- t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
- }
-
- ep.Close()
-
- proto.mu.Lock()
- _, hasEndpointAfterClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
- if hasEndpointAfterClose {
- t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
- }
-}
-
func TestMultiCounterStatsInitialization(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 7298bd061..8db2454d3 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -290,7 +290,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
received.invalid.Increment()
return
} else if e.nud != nil {
- e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e)
} else {
e.linkAddrCache.AddLinkAddress(srcAddr, sourceLinkAddr)
}
@@ -578,7 +578,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
if e.nud != nil {
// A RS with a specified source IP address modifies the NUD state
// machine in the same way a reachability probe would.
- e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e)
}
}
@@ -628,7 +628,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// If the RA has the source link layer option, update the link address
// cache with the link address for the advertised router.
if len(sourceLinkAddr) != 0 && e.nud != nil {
- e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e)
}
e.mu.Lock()
@@ -694,24 +694,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
}
-var _ stack.LinkAddressResolver = (*protocol)(nil)
-
// LinkAddressProtocol implements stack.LinkAddressResolver.
-func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv6ProtocolNumber
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error {
- nicID := nic.ID()
-
- p.mu.Lock()
- netEP, ok := p.mu.eps[nicID]
- p.mu.Unlock()
- if !ok {
- return &tcpip.ErrNotConnected{}
- }
-
+func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
remoteAddr := targetAddr
if len(remoteLinkAddr) == 0 {
remoteAddr = header.SolicitedNodeAddr(targetAddr)
@@ -719,22 +708,22 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
if len(localAddr) == 0 {
- addressEndpoint := netEP.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */)
+ addressEndpoint := e.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */)
if addressEndpoint == nil {
return &tcpip.ErrNetworkUnreachable{}
}
localAddr = addressEndpoint.AddressWithPrefix().Address
- } else if p.stack.CheckLocalAddress(nicID, ProtocolNumber, localAddr) == 0 {
+ } else if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, localAddr) == 0 {
return &tcpip.ErrBadLocalAddress{}
}
optsSerializer := header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()),
+ header.NDPSourceLinkLayerAddressOption(e.nic.LinkAddress()),
}
neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize,
+ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize,
})
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
@@ -751,9 +740,9 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
- stat := netEP.stats.icmp.packetsSent
+ stat := e.stats.icmp.packetsSent
- if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
stat.dropped.Increment()
return err
}
@@ -763,7 +752,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
-func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if header.IsV6MulticastAddress(addr) {
return header.EthernetAddressFromMulticastIPv6Address(addr), true
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index db1c2e663..a5c88444e 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -1346,29 +1346,32 @@ func TestLinkAddressRequest(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
- p := s.NetworkProtocolInstance(ProtocolNumber)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
- }
linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
if err := s.CreateNIC(nicID, linkEP); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
+
+ ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err)
+ }
+ linkRes, ok := ep.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep)
+ }
+
if len(test.nicAddr) != 0 {
if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
}
}
- // We pass a test network interface to LinkAddressRequest with the same NIC
- // ID and link endpoint used by the NIC we created earlier so that we can
- // mock a link address request and observe the packets sent to the link
- // endpoint even though the stack uses the real NIC.
- err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID})
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff)
+ {
+ err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr)
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff)
+ }
}
if test.expectedErr != nil {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index caa62b3a2..b55a35525 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -164,6 +164,7 @@ func getLabel(addr tcpip.Address) uint8 {
panic(fmt.Sprintf("should have a label for address = %s", addr))
}
+var _ stack.LinkAddressResolver = (*endpoint)(nil)
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 63a42a2ea..1e4ddf0d5 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -41,6 +41,7 @@ const (
protocolNumberOffset = 2
)
+var _ LinkAddressResolver = (*fwdTestNetworkEndpoint)(nil)
var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
@@ -153,7 +154,6 @@ type fwdTestNetworkEndpointStats struct{}
// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {}
-var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
// fwdTestNetworkProtocol is a network-layer protocol that implements Address
@@ -219,23 +219,23 @@ func (*fwdTestNetworkProtocol) Close() {}
func (*fwdTestNetworkProtocol) Wait() {}
-func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
- if f.onLinkAddressResolved != nil {
- time.AfterFunc(f.addrResolveDelay, func() {
- f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr)
+func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
+ if fn := f.proto.onLinkAddressResolved; fn != nil {
+ time.AfterFunc(f.proto.addrResolveDelay, func() {
+ fn(f.proto.addrCache, f.proto.neigh, addr, remoteLinkAddr)
})
}
return nil
}
-func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
- if f.onResolveStaticAddress != nil {
- return f.onResolveStaticAddress(addr)
+func (f *fwdTestNetworkEndpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if fn := f.proto.onResolveStaticAddress; fn != nil {
+ return fn(addr)
}
return "", false
}
-func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return fwdTestNetNumber
}
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 930b8f795..cd2bb3417 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -199,7 +199,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry {
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
c.mu.Lock()
entry := c.getOrCreateEntryLocked(k)
entry.mu.Lock()
@@ -224,7 +224,7 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA
}
if entry.mu.done == nil {
entry.mu.done = make(chan struct{})
- go c.startAddressResolution(k, linkRes, localAddr, nic, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ go c.startAddressResolution(k, linkRes, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
}
return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{}
default:
@@ -232,11 +232,11 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA
}
}
-func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
+func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */, nic)
+ linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */)
select {
case now := <-time.After(c.resolutionTimeout):
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 466a5e8d9..40017c8b6 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -48,7 +48,7 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
+func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error {
// TODO(gvisor.dev/issue/5141): Use a fake clock.
time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) })
if f := r.onLinkAddressRequest; f != nil {
@@ -80,7 +80,7 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe
func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, tcpip.Error) {
var attemptedResolution bool
for {
- got, ch, err := c.get(addr, linkRes, "", nil, nil)
+ got, ch, err := c.get(addr, linkRes, "", nil)
if _, ok := err.(*tcpip.ErrWouldBlock); ok {
if attemptedResolution {
return got, &tcpip.ErrTimeout{}
@@ -104,23 +104,23 @@ func TestCacheOverflow(t *testing.T) {
for i := len(testAddrs) - 1; i >= 0; i-- {
e := testAddrs[i]
c.AddLinkAddress(e.addr, e.linkAddr)
- got, _, err := c.get(e.addr, nil, "", nil, nil)
+ got, _, err := c.get(e.addr, nil, "", nil)
if err != nil {
- t.Errorf("insert %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err)
+ t.Errorf("insert %d, c.get(%s, nil, '', nil): %s", i, e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("insert %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
+ t.Errorf("insert %d, got c.get(%s, nil, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
}
}
// Expect to find at least half of the most recent entries.
for i := 0; i < linkAddrCacheSize/2; i++ {
e := testAddrs[i]
- got, _, err := c.get(e.addr, nil, "", nil, nil)
+ got, _, err := c.get(e.addr, nil, "", nil)
if err != nil {
- t.Errorf("check %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err)
+ t.Errorf("check %d, c.get(%s, nil, '', nil): %s", i, e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("check %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
+ t.Errorf("check %d, got c.get(%s, nil, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
}
}
// The earliest entries should no longer be in the cache.
@@ -154,12 +154,12 @@ func TestCacheConcurrent(t *testing.T) {
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
e := testAddrs[len(testAddrs)-1]
- got, _, err := c.get(e.addr, linkRes, "", nil, nil)
+ got, _, err := c.get(e.addr, linkRes, "", nil)
if err != nil {
- t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err)
+ t.Errorf("c.get(%s, _, '', nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr)
+ t.Errorf("got c.get(%s, _, '', nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
e = testAddrs[0]
@@ -177,9 +177,9 @@ func TestCacheAgeLimit(t *testing.T) {
e := testAddrs[0]
c.AddLinkAddress(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
- _, _, err := c.get(e.addr, linkRes, "", nil, nil)
+ _, _, err := c.get(e.addr, linkRes, "", nil)
if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = ErrWouldBlock", e.addr, err)
+ t.Errorf("got c.get(%s, _, '', nil) = %s, want = ErrWouldBlock", e.addr, err)
}
}
@@ -188,21 +188,21 @@ func TestCacheReplace(t *testing.T) {
e := testAddrs[0]
l2 := e.linkAddr + "2"
c.AddLinkAddress(e.addr, e.linkAddr)
- got, _, err := c.get(e.addr, nil, "", nil, nil)
+ got, _, err := c.get(e.addr, nil, "", nil)
if err != nil {
- t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err)
+ t.Errorf("c.get(%s, nil, '', nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr)
+ t.Errorf("got c.get(%s, nil, '', nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
c.AddLinkAddress(e.addr, l2)
- got, _, err = c.get(e.addr, nil, "", nil, nil)
+ got, _, err = c.get(e.addr, nil, "", nil)
if err != nil {
- t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err)
+ t.Errorf("c.get(%s, nil, '', nil): %s", e.addr, err)
}
if got != l2 {
- t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, l2)
+ t.Errorf("got c.get(%s, nil, '', nil) = %s, want = %s", e.addr, got, l2)
}
}
@@ -228,12 +228,12 @@ func TestCacheResolution(t *testing.T) {
// Check that after resolved, address stays in the cache and never returns WouldBlock.
for i := 0; i < 10; i++ {
e := testAddrs[len(testAddrs)-1]
- got, _, err := c.get(e.addr, linkRes, "", nil, nil)
+ got, _, err := c.get(e.addr, linkRes, "", nil)
if err != nil {
- t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err)
+ t.Errorf("c.get(%s, _, '', nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr)
+ t.Errorf("got c.get(%s, _, '', nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
}
}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index 2870e4f66..0f7925774 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -194,7 +194,7 @@ type testNeighborResolver struct {
var _ LinkAddressResolver = (*testNeighborResolver)(nil)
-func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
+func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error {
if !r.dropReplies {
// Delay handling the request to emulate network latency.
r.clock.AfterFunc(r.delay, func() {
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 53ac9bb6e..a037ca6f9 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -254,7 +254,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil {
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
@@ -340,7 +340,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// address SHOULD be placed in the IP Source Address of the outgoing
// solicitation.
//
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil {
// There is no need to log the error here; the NUD implementation may
// assume a working link. A valid link should be the responsibility of
// the NIC/stack.LinkEndpoint.
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index 140b8ca00..c5c3d266b 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -193,7 +193,7 @@ func (p entryTestProbeInfo) String() string {
// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
// to the local network if linkAddr is the zero value.
-func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
+func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error {
p := entryTestProbeInfo{
RemoteAddress: targetAddr,
RemoteLinkAddress: linkAddr,
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index e56a624fe..0707c3ce2 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -42,7 +42,8 @@ type NIC struct {
// The network endpoints themselves may be modified by calling the interface's
// methods, but the map reference and entries must be constant.
- networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
+ networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
+ linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
// enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
//
@@ -133,12 +134,13 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic := &NIC{
LinkEndpoint: ep,
- stack: stack,
- id: id,
- name: name,
- context: ctx,
- stats: makeNICStats(),
- networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
+ stack: stack,
+ id: id,
+ name: name,
+ context: ctx,
+ stats: makeNICStats(),
+ networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
}
nic.linkResQueue.init(nic)
nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts)
@@ -146,7 +148,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// Check for Neighbor Unreachability Detection support.
var nud NUDHandler
- if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 && stack.useNeighborCache {
+ if ep.Capabilities()&CapabilityResolutionRequired != 0 && stack.useNeighborCache {
rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds()))
nic.neigh = &neighborCache{
nic: nic,
@@ -170,7 +172,13 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
nic.mu.packetEPs[netNum] = new(packetEndpointList)
- nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic)
+
+ netEP := netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic)
+ nic.networkEndpoints[netNum] = netEP
+
+ if r, ok := netEP.(LinkAddressResolver); ok {
+ nic.linkAddrResolvers[r.LinkAddressProtocol()] = r
+ }
}
nic.LinkEndpoint.Attach(nic)
@@ -593,13 +601,28 @@ func (n *NIC) confirmReachable(addr tcpip.Address) {
}
}
+func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error {
+ linkRes, ok := n.linkAddrResolvers[protocol]
+ if !ok {
+ return &tcpip.ErrNotSupported{}
+ }
+
+ if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok {
+ onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true})
+ return nil
+ }
+
+ _, _, err := n.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve)
+ return err
+}
+
func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
if n.neigh != nil {
entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve)
return entry.LinkAddr, ch, err
}
- return n.linkAddrCache.get(addr, linkRes, localAddr, n, onResolve)
+ return n.linkAddrCache.get(addr, linkRes, localAddr, onResolve)
}
func (n *NIC) neighbors() ([]NeighborEntry, tcpip.Error) {
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 2f719fbe5..3564202d8 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -111,8 +111,6 @@ type testIPv6EndpointStats struct{}
// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
func (*testIPv6EndpointStats) IsNetworkEndpointStats() {}
-var _ LinkAddressResolver = (*testIPv6Protocol)(nil)
-
// We use this instead of ipv6.protocol because the ipv6 package depends on
// the stack package which this test lives in, causing a cyclic dependency.
type testIPv6Protocol struct{}
@@ -169,24 +167,6 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo
return 0, false, false
}
-// LinkAddressProtocol implements LinkAddressResolver.
-func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
- return header.IPv6ProtocolNumber
-}
-
-// LinkAddressRequest implements LinkAddressResolver.
-func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
- return nil
-}
-
-// ResolveStaticAddress implements LinkAddressResolver.
-func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
- if header.IsV6MulticastAddress(addr) {
- return header.EthernetAddressFromMulticastIPv6Address(addr), true
- }
- return "", false
-}
-
func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
// When the NIC is disabled, the only field that matters is the stats field.
// This test is limited to stats counter checks.
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index ebfd5eb45..504acc246 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -72,17 +72,17 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) {
}
// TestNUDConfigurationFailsForNotSupported tests to make sure we get a
-// NotSupported error if we attempt to retrieve NUD configurations when the
-// stack doesn't support NUD.
+// NotSupported error if we attempt to retrieve or set NUD configurations when
+// the stack doesn't support NUD.
//
// The stack will report to not support NUD if a neighbor cache for a given NIC
// is not allocated. The networking stack will only allocate neighbor caches if
-// a protocol providing link address resolution is specified (e.g. ARP, IPv6).
+// the NIC requires link resolution.
func TestNUDConfigurationFailsForNotSupported(t *testing.T) {
const nicID = 1
e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired
s := stack.New(stack.Options{
NUDConfigs: stack.DefaultNUDConfigurations(),
@@ -91,38 +91,21 @@ func TestNUDConfigurationFailsForNotSupported(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- _, err := s.NUDConfigurations(nicID)
- if _, ok := err.(*tcpip.ErrNotSupported); !ok {
- t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{})
- }
-}
-// TestNUDConfigurationFailsForNotSupported tests to make sure we get a
-// NotSupported error if we attempt to set NUD configurations when the stack
-// doesn't support NUD.
-//
-// The stack will report to not support NUD if a neighbor cache for a given NIC
-// is not allocated. The networking stack will only allocate neighbor caches if
-// a protocol providing link address resolution is specified (e.g. ARP, IPv6).
-func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) {
- const nicID = 1
-
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- NUDConfigs: stack.DefaultNUDConfigurations(),
- UseNeighborCache: true,
+ t.Run("Get", func(t *testing.T) {
+ _, err := s.NUDConfigurations(nicID)
+ if _, ok := err.(*tcpip.ErrNotSupported); !ok {
+ t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{})
+ }
})
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- config := stack.NUDConfigurations{}
- err := s.SetNUDConfigurations(nicID, config)
- if _, ok := err.(*tcpip.ErrNotSupported); !ok {
- t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{})
- }
+ t.Run("Set", func(t *testing.T) {
+ config := stack.NUDConfigurations{}
+ err := s.SetNUDConfigurations(nicID, config)
+ if _, ok := err.(*tcpip.ErrNotSupported); !ok {
+ t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{})
+ }
+ })
}
// TestDefaultNUDConfigurationIsValid verifies that calling
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 510da8689..64b5627e1 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -824,16 +824,12 @@ type InjectableLinkEndpoint interface {
InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error
}
-// A LinkAddressResolver is an extension to a NetworkProtocol that
-// can resolve link addresses.
+// A LinkAddressResolver handles link address resolution for a network protocol.
type LinkAddressResolver interface {
// LinkAddressRequest sends a request for the link address of the target
// address. The request is broadcasted on the local network if a remote link
// address is not provided.
- //
- // The request is sent from the passed network interface. If the interface
- // local address is unspecified, any interface local address may be used.
- LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) tcpip.Error
+ LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 4ae0f2a1a..1c8ef6ed4 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -174,7 +174,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA
}
if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
- if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok {
+ if linkRes, ok := r.outgoingNIC.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
}
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 119c4c505..e720d676f 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -376,7 +376,6 @@ func (u *uniqueIDGenerator) UniqueID() uint64 {
type Stack struct {
transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
- linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
// rawFactory creates raw endpoints. If nil, raw endpoints are
// disabled. It is set during Stack creation and is immutable.
@@ -635,7 +634,6 @@ func New(opts Options) *Stack {
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
nics: make(map[tcpip.NICID]*NIC),
cleanupEndpoints: make(map[TransportEndpoint]struct{}),
PortManager: ports.NewPortManager(),
@@ -666,9 +664,6 @@ func New(opts Options) *Stack {
for _, netProtoFactory := range opts.NetworkProtocols {
netProto := netProtoFactory(s)
s.networkProtocols[netProto.Number()] = netProto
- if r, ok := netProto.(LinkAddressResolver); ok {
- s.linkAddrResolvers[r.LinkAddressProtocol()] = r
- }
}
// Add specified transport protocols.
@@ -1561,18 +1556,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
return &tcpip.ErrUnknownNICID{}
}
- linkRes, ok := s.linkAddrResolvers[protocol]
- if !ok {
- return &tcpip.ErrNotSupported{}
- }
-
- if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok {
- onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true})
- return nil
- }
-
- _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve)
- return err
+ return nic.getLinkAddress(addr, localAddr, protocol, onResolve)
}
// Neighbors returns all IP to MAC address associations.