summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/arp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network/arp')
-rw-r--r--pkg/tcpip/network/arp/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go59
-rw-r--r--pkg/tcpip/network/arp/arp_test.go86
-rw-r--r--pkg/tcpip/network/arp/stats_test.go42
4 files changed, 32 insertions, 156 deletions
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index c7ab876bf..933845269 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -10,7 +10,6 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
- "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 7838cc753..5c79d6485 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -22,7 +22,6 @@ import (
"reflect"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -35,6 +34,8 @@ const (
ProtocolNumber = header.ARPProtocolNumber
)
+var _ stack.LinkAddressResolver = (*endpoint)(nil)
+
// ARP endpoints need to implement stack.NetworkEndpoint because the stack
// considers the layer above the link-layer a network layer; the only
// facility provided by the stack to deliver packets to a layer above
@@ -101,9 +102,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) Close() {
- e.protocol.forgetEndpoint(e.nic.ID())
-}
+func (*endpoint) Close() {}
func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error {
return &tcpip.ErrNotSupported{}
@@ -154,7 +153,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if e.nud == nil {
e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr)
} else {
- e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
+ e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e)
}
respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -221,19 +220,10 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
}
var _ stack.NetworkProtocol = (*protocol)(nil)
-var _ stack.LinkAddressResolver = (*protocol)(nil)
// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
type protocol struct {
stack *stack.Stack
-
- mu struct {
- sync.RWMutex
-
- // eps is keyed by NICID to allow protocol methods to retrieve the correct
- // endpoint depending on the NIC.
- eps map[tcpip.NICID]*endpoint
- }
}
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
@@ -257,43 +247,26 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
stackStats := p.stack.Stats()
e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP)
- p.mu.Lock()
- p.mu.eps[nic.ID()] = e
- p.mu.Unlock()
-
return e
}
-func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
- p.mu.Lock()
- defer p.mu.Unlock()
- delete(p.mu.eps, nicID)
-}
-
// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
-func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error {
- nicID := nic.ID()
-
- p.mu.Lock()
- netEP, ok := p.mu.eps[nicID]
- p.mu.Unlock()
- if !ok {
- return &tcpip.ErrNotConnected{}
- }
+func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
+ nicID := e.nic.ID()
- stats := netEP.stats.arp
+ stats := e.stats.arp
if len(remoteLinkAddr) == 0 {
remoteLinkAddr = header.EthernetBroadcastAddress
}
if len(localAddr) == 0 {
- addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
+ addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
if !ok {
return &tcpip.ErrUnknownNICID{}
}
@@ -304,13 +277,13 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
localAddr = addr.Address
- } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
stats.outgoingRequestBadLocalAddressErrors.Increment()
return &tcpip.ErrBadLocalAddress{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize,
+ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
pkt.NetworkProtocolNumber = ProtocolNumber
@@ -318,14 +291,14 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
h.SetOp(header.ARPRequest)
// TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
// link address.
- _ = copy(h.HardwareAddressSender(), nic.LinkAddress())
+ _ = copy(h.HardwareAddressSender(), e.nic.LinkAddress())
if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize {
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize {
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
- if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
stats.outgoingRequestsDropped.Increment()
return err
}
@@ -334,7 +307,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
-func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
return header.EthernetBroadcastAddress, true
}
@@ -369,9 +342,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return &protocol{
stack: s,
- mu: struct {
- sync.RWMutex
- eps map[tcpip.NICID]*endpoint
- }{eps: make(map[tcpip.NICID]*endpoint)},
}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index b0f07aa44..d753a97af 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -530,52 +530,19 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
}
}
-var _ stack.NetworkInterface = (*testInterface)(nil)
+var _ stack.LinkEndpoint = (*testLinkEndpoint)(nil)
-type testInterface struct {
+type testLinkEndpoint struct {
stack.LinkEndpoint
- nicID tcpip.NICID
-
writeErr tcpip.Error
}
-func (t *testInterface) ID() tcpip.NICID {
- return t.nicID
-}
-
-func (*testInterface) IsLoopback() bool {
- return false
-}
-
-func (*testInterface) Name() string {
- return ""
-}
-
-func (*testInterface) Enabled() bool {
- return true
-}
-
-func (*testInterface) Promiscuous() bool {
- return false
-}
-
-func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt)
-}
-
-func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol)
-}
-
-func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if t.writeErr != nil {
return t.writeErr
}
- var r stack.RouteInfo
- r.NetProto = protocol
- r.RemoteLinkAddress = remoteLinkAddr
return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt)
}
@@ -709,32 +676,31 @@ func TestLinkAddressRequest(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
})
- p := s.NetworkProtocolInstance(arp.ProtocolNumber)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
- }
-
linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
- if err := s.CreateNIC(nicID, linkEP); err != nil {
+ if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
+ ep, err := s.GetNetworkEndpoint(nicID, arp.ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, arp.ProtocolNumber, err)
+ }
+ linkRes, ok := ep.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep)
+ }
+
if len(test.nicAddr) != 0 {
if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
}
}
- // We pass a test network interface to LinkAddressRequest with the same
- // NIC ID and link endpoint used by the NIC we created earlier so that we
- // can mock a link address request and observe the packets sent to the
- // link endpoint even though the stack uses the real NIC to validate the
- // local address.
- iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr}
- err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface)
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff)
+ {
+ err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr)
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff)
+ }
}
if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent {
@@ -782,19 +748,3 @@ func TestLinkAddressRequest(t *testing.T) {
})
}
}
-
-func TestLinkAddressRequestWithoutNIC(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
- })
- p := s.NetworkProtocolInstance(arp.ProtocolNumber)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
- }
-
- err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID})
- if _, ok := err.(*tcpip.ErrNotConnected); !ok {
- t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, &tcpip.ErrNotConnected{})
- }
-}
diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go
index 036fdf739..d3b56c635 100644
--- a/pkg/tcpip/network/arp/stats_test.go
+++ b/pkg/tcpip/network/arp/stats_test.go
@@ -34,48 +34,6 @@ func (t *testInterface) ID() tcpip.NICID {
return t.nicID
}
-func knownNICIDs(proto *protocol) []tcpip.NICID {
- var nicIDs []tcpip.NICID
-
- for k := range proto.mu.eps {
- nicIDs = append(nicIDs, k)
- }
-
- return nicIDs
-}
-
-func TestClearEndpointFromProtocolOnClose(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- nic := testInterface{nicID: 1}
- ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
- var nicIDs []tcpip.NICID
-
- proto.mu.Lock()
- foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
-
- if !hasEndpointBeforeClose {
- t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs)
- }
- if foundEP != ep {
- t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
- }
-
- ep.Close()
-
- proto.mu.Lock()
- _, hasEndpointAfterClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
- if hasEndpointAfterClose {
- t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
- }
-}
-
func TestMultiCounterStatsInitialization(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},