summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/arp
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-01-30 11:35:35 -0800
committergVisor bot <gvisor-bot@google.com>2021-01-30 11:37:29 -0800
commit2d90bc54809766927e6028fac5b9f67cd2a13c3e (patch)
tree905a197ce947003d406b880fc1d096c60de0d342 /pkg/tcpip/network/arp
parent825c185dc56251bd330124ef773c6653e3887579 (diff)
Implement LinkAddressResolver on NetworkEndpoints
This removes the need to provide the link address request with the NIC the request is being performed on since the NetworkEndpoints already have a reference to the NIC. PiperOrigin-RevId: 354721940
Diffstat (limited to 'pkg/tcpip/network/arp')
-rw-r--r--pkg/tcpip/network/arp/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go59
-rw-r--r--pkg/tcpip/network/arp/arp_test.go86
-rw-r--r--pkg/tcpip/network/arp/stats_test.go42
4 files changed, 32 insertions, 156 deletions
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index c7ab876bf..933845269 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -10,7 +10,6 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
- "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 7838cc753..5c79d6485 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -22,7 +22,6 @@ import (
"reflect"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -35,6 +34,8 @@ const (
ProtocolNumber = header.ARPProtocolNumber
)
+var _ stack.LinkAddressResolver = (*endpoint)(nil)
+
// ARP endpoints need to implement stack.NetworkEndpoint because the stack
// considers the layer above the link-layer a network layer; the only
// facility provided by the stack to deliver packets to a layer above
@@ -101,9 +102,7 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) Close() {
- e.protocol.forgetEndpoint(e.nic.ID())
-}
+func (*endpoint) Close() {}
func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error {
return &tcpip.ErrNotSupported{}
@@ -154,7 +153,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if e.nud == nil {
e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr)
} else {
- e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
+ e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e)
}
respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -221,19 +220,10 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
}
var _ stack.NetworkProtocol = (*protocol)(nil)
-var _ stack.LinkAddressResolver = (*protocol)(nil)
// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
type protocol struct {
stack *stack.Stack
-
- mu struct {
- sync.RWMutex
-
- // eps is keyed by NICID to allow protocol methods to retrieve the correct
- // endpoint depending on the NIC.
- eps map[tcpip.NICID]*endpoint
- }
}
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
@@ -257,43 +247,26 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
stackStats := p.stack.Stats()
e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP)
- p.mu.Lock()
- p.mu.eps[nic.ID()] = e
- p.mu.Unlock()
-
return e
}
-func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
- p.mu.Lock()
- defer p.mu.Unlock()
- delete(p.mu.eps, nicID)
-}
-
// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
-func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error {
- nicID := nic.ID()
-
- p.mu.Lock()
- netEP, ok := p.mu.eps[nicID]
- p.mu.Unlock()
- if !ok {
- return &tcpip.ErrNotConnected{}
- }
+func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
+ nicID := e.nic.ID()
- stats := netEP.stats.arp
+ stats := e.stats.arp
if len(remoteLinkAddr) == 0 {
remoteLinkAddr = header.EthernetBroadcastAddress
}
if len(localAddr) == 0 {
- addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
+ addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
if !ok {
return &tcpip.ErrUnknownNICID{}
}
@@ -304,13 +277,13 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
localAddr = addr.Address
- } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
stats.outgoingRequestBadLocalAddressErrors.Increment()
return &tcpip.ErrBadLocalAddress{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize,
+ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
pkt.NetworkProtocolNumber = ProtocolNumber
@@ -318,14 +291,14 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
h.SetOp(header.ARPRequest)
// TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
// link address.
- _ = copy(h.HardwareAddressSender(), nic.LinkAddress())
+ _ = copy(h.HardwareAddressSender(), e.nic.LinkAddress())
if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize {
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize {
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
- if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
stats.outgoingRequestsDropped.Increment()
return err
}
@@ -334,7 +307,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
-func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
return header.EthernetBroadcastAddress, true
}
@@ -369,9 +342,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return &protocol{
stack: s,
- mu: struct {
- sync.RWMutex
- eps map[tcpip.NICID]*endpoint
- }{eps: make(map[tcpip.NICID]*endpoint)},
}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index b0f07aa44..d753a97af 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -530,52 +530,19 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
}
}
-var _ stack.NetworkInterface = (*testInterface)(nil)
+var _ stack.LinkEndpoint = (*testLinkEndpoint)(nil)
-type testInterface struct {
+type testLinkEndpoint struct {
stack.LinkEndpoint
- nicID tcpip.NICID
-
writeErr tcpip.Error
}
-func (t *testInterface) ID() tcpip.NICID {
- return t.nicID
-}
-
-func (*testInterface) IsLoopback() bool {
- return false
-}
-
-func (*testInterface) Name() string {
- return ""
-}
-
-func (*testInterface) Enabled() bool {
- return true
-}
-
-func (*testInterface) Promiscuous() bool {
- return false
-}
-
-func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt)
-}
-
-func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol)
-}
-
-func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if t.writeErr != nil {
return t.writeErr
}
- var r stack.RouteInfo
- r.NetProto = protocol
- r.RemoteLinkAddress = remoteLinkAddr
return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt)
}
@@ -709,32 +676,31 @@ func TestLinkAddressRequest(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
})
- p := s.NetworkProtocolInstance(arp.ProtocolNumber)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
- }
-
linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
- if err := s.CreateNIC(nicID, linkEP); err != nil {
+ if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
+ ep, err := s.GetNetworkEndpoint(nicID, arp.ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, arp.ProtocolNumber, err)
+ }
+ linkRes, ok := ep.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep)
+ }
+
if len(test.nicAddr) != 0 {
if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
}
}
- // We pass a test network interface to LinkAddressRequest with the same
- // NIC ID and link endpoint used by the NIC we created earlier so that we
- // can mock a link address request and observe the packets sent to the
- // link endpoint even though the stack uses the real NIC to validate the
- // local address.
- iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr}
- err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface)
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff)
+ {
+ err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr)
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff)
+ }
}
if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent {
@@ -782,19 +748,3 @@ func TestLinkAddressRequest(t *testing.T) {
})
}
}
-
-func TestLinkAddressRequestWithoutNIC(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
- })
- p := s.NetworkProtocolInstance(arp.ProtocolNumber)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
- }
-
- err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID})
- if _, ok := err.(*tcpip.ErrNotConnected); !ok {
- t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, &tcpip.ErrNotConnected{})
- }
-}
diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go
index 036fdf739..d3b56c635 100644
--- a/pkg/tcpip/network/arp/stats_test.go
+++ b/pkg/tcpip/network/arp/stats_test.go
@@ -34,48 +34,6 @@ func (t *testInterface) ID() tcpip.NICID {
return t.nicID
}
-func knownNICIDs(proto *protocol) []tcpip.NICID {
- var nicIDs []tcpip.NICID
-
- for k := range proto.mu.eps {
- nicIDs = append(nicIDs, k)
- }
-
- return nicIDs
-}
-
-func TestClearEndpointFromProtocolOnClose(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- nic := testInterface{nicID: 1}
- ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
- var nicIDs []tcpip.NICID
-
- proto.mu.Lock()
- foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
-
- if !hasEndpointBeforeClose {
- t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs)
- }
- if foundEP != ep {
- t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
- }
-
- ep.Close()
-
- proto.mu.Lock()
- _, hasEndpointAfterClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
- if hasEndpointAfterClose {
- t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
- }
-}
-
func TestMultiCounterStatsInitialization(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},