diff options
Diffstat (limited to 'pkg/tcpip/network/arp')
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 59 |
1 files changed, 14 insertions, 45 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7838cc753..5c79d6485 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -22,7 +22,6 @@ import ( "reflect" "sync/atomic" - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -35,6 +34,8 @@ const ( ProtocolNumber = header.ARPProtocolNumber ) +var _ stack.LinkAddressResolver = (*endpoint)(nil) + // ARP endpoints need to implement stack.NetworkEndpoint because the stack // considers the layer above the link-layer a network layer; the only // facility provided by the stack to deliver packets to a layer above @@ -101,9 +102,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.ARPSize } -func (e *endpoint) Close() { - e.protocol.forgetEndpoint(e.nic.ID()) -} +func (*endpoint) Close() {} func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} @@ -154,7 +153,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if e.nud == nil { e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr) } else { - e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) + e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e) } respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -221,19 +220,10 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { } var _ stack.NetworkProtocol = (*protocol)(nil) -var _ stack.LinkAddressResolver = (*protocol)(nil) // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { stack *stack.Stack - - mu struct { - sync.RWMutex - - // eps is keyed by NICID to allow protocol methods to retrieve the correct - // endpoint depending on the NIC. - eps map[tcpip.NICID]*endpoint - } } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -257,43 +247,26 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L stackStats := p.stack.Stats() e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP) - p.mu.Lock() - p.mu.eps[nic.ID()] = e - p.mu.Unlock() - return e } -func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { - p.mu.Lock() - defer p.mu.Unlock() - delete(p.mu.eps, nicID) -} - // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. -func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error { - nicID := nic.ID() - - p.mu.Lock() - netEP, ok := p.mu.eps[nicID] - p.mu.Unlock() - if !ok { - return &tcpip.ErrNotConnected{} - } +func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { + nicID := e.nic.ID() - stats := netEP.stats.arp + stats := e.stats.arp if len(remoteLinkAddr) == 0 { remoteLinkAddr = header.EthernetBroadcastAddress } if len(localAddr) == 0 { - addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) if !ok { return &tcpip.ErrUnknownNICID{} } @@ -304,13 +277,13 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } localAddr = addr.Address - } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { stats.outgoingRequestBadLocalAddressErrors.Increment() return &tcpip.ErrBadLocalAddress{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize, + ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) pkt.NetworkProtocolNumber = ProtocolNumber @@ -318,14 +291,14 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot h.SetOp(header.ARPRequest) // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a // link address. - _ = copy(h.HardwareAddressSender(), nic.LinkAddress()) + _ = copy(h.HardwareAddressSender(), e.nic.LinkAddress()) if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize { panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } - if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { stats.outgoingRequestsDropped.Increment() return err } @@ -334,7 +307,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. -func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { return header.EthernetBroadcastAddress, true } @@ -369,9 +342,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return &protocol{ stack: s, - mu: struct { - sync.RWMutex - eps map[tcpip.NICID]*endpoint - }{eps: make(map[tcpip.NICID]*endpoint)}, } } |