summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/arp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network/arp')
-rw-r--r--pkg/tcpip/network/arp/arp.go59
1 files changed, 14 insertions, 45 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 7838cc753..5c79d6485 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -22,7 +22,6 @@ import (
"reflect"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -35,6 +34,8 @@ const (
ProtocolNumber = header.ARPProtocolNumber
)
+var _ stack.LinkAddressResolver = (*endpoint)(nil)
+
// ARP endpoints need to implement stack.NetworkEndpoint because the stack
// considers the layer above the link-layer a network layer; the only
// facility provided by the stack to deliver packets to a layer above
@@ -101,9 +102,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) Close() {
- e.protocol.forgetEndpoint(e.nic.ID())
-}
+func (*endpoint) Close() {}
func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error {
return &tcpip.ErrNotSupported{}
@@ -154,7 +153,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if e.nud == nil {
e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr)
} else {
- e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
+ e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e)
}
respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -221,19 +220,10 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
}
var _ stack.NetworkProtocol = (*protocol)(nil)
-var _ stack.LinkAddressResolver = (*protocol)(nil)
// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
type protocol struct {
stack *stack.Stack
-
- mu struct {
- sync.RWMutex
-
- // eps is keyed by NICID to allow protocol methods to retrieve the correct
- // endpoint depending on the NIC.
- eps map[tcpip.NICID]*endpoint
- }
}
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
@@ -257,43 +247,26 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
stackStats := p.stack.Stats()
e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP)
- p.mu.Lock()
- p.mu.eps[nic.ID()] = e
- p.mu.Unlock()
-
return e
}
-func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
- p.mu.Lock()
- defer p.mu.Unlock()
- delete(p.mu.eps, nicID)
-}
-
// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
-func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error {
- nicID := nic.ID()
-
- p.mu.Lock()
- netEP, ok := p.mu.eps[nicID]
- p.mu.Unlock()
- if !ok {
- return &tcpip.ErrNotConnected{}
- }
+func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
+ nicID := e.nic.ID()
- stats := netEP.stats.arp
+ stats := e.stats.arp
if len(remoteLinkAddr) == 0 {
remoteLinkAddr = header.EthernetBroadcastAddress
}
if len(localAddr) == 0 {
- addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
+ addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
if !ok {
return &tcpip.ErrUnknownNICID{}
}
@@ -304,13 +277,13 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
localAddr = addr.Address
- } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
stats.outgoingRequestBadLocalAddressErrors.Increment()
return &tcpip.ErrBadLocalAddress{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize,
+ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
pkt.NetworkProtocolNumber = ProtocolNumber
@@ -318,14 +291,14 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
h.SetOp(header.ARPRequest)
// TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
// link address.
- _ = copy(h.HardwareAddressSender(), nic.LinkAddress())
+ _ = copy(h.HardwareAddressSender(), e.nic.LinkAddress())
if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize {
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize {
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
- if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
stats.outgoingRequestsDropped.Increment()
return err
}
@@ -334,7 +307,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
-func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
return header.EthernetBroadcastAddress, true
}
@@ -369,9 +342,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return &protocol{
stack: s,
- mu: struct {
- sync.RWMutex
- eps map[tcpip.NICID]*endpoint
- }{eps: make(map[tcpip.NICID]*endpoint)},
}
}