diff options
Diffstat (limited to 'pkg/tcpip/network/arp')
-rw-r--r-- | pkg/tcpip/network/arp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 59 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/arp_test.go | 86 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/stats_test.go | 42 |
4 files changed, 32 insertions, 156 deletions
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index c7ab876bf..933845269 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -10,7 +10,6 @@ go_library( ], visibility = ["//visibility:public"], deps = [ - "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7838cc753..5c79d6485 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -22,7 +22,6 @@ import ( "reflect" "sync/atomic" - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -35,6 +34,8 @@ const ( ProtocolNumber = header.ARPProtocolNumber ) +var _ stack.LinkAddressResolver = (*endpoint)(nil) + // ARP endpoints need to implement stack.NetworkEndpoint because the stack // considers the layer above the link-layer a network layer; the only // facility provided by the stack to deliver packets to a layer above @@ -101,9 +102,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.ARPSize } -func (e *endpoint) Close() { - e.protocol.forgetEndpoint(e.nic.ID()) -} +func (*endpoint) Close() {} func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} @@ -154,7 +153,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if e.nud == nil { e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr) } else { - e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) + e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e) } respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -221,19 +220,10 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { } var _ stack.NetworkProtocol = (*protocol)(nil) -var _ stack.LinkAddressResolver = (*protocol)(nil) // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { stack *stack.Stack - - mu struct { - sync.RWMutex - - // eps is keyed by NICID to allow protocol methods to retrieve the correct - // endpoint depending on the NIC. - eps map[tcpip.NICID]*endpoint - } } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -257,43 +247,26 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L stackStats := p.stack.Stats() e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP) - p.mu.Lock() - p.mu.eps[nic.ID()] = e - p.mu.Unlock() - return e } -func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { - p.mu.Lock() - defer p.mu.Unlock() - delete(p.mu.eps, nicID) -} - // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. -func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error { - nicID := nic.ID() - - p.mu.Lock() - netEP, ok := p.mu.eps[nicID] - p.mu.Unlock() - if !ok { - return &tcpip.ErrNotConnected{} - } +func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { + nicID := e.nic.ID() - stats := netEP.stats.arp + stats := e.stats.arp if len(remoteLinkAddr) == 0 { remoteLinkAddr = header.EthernetBroadcastAddress } if len(localAddr) == 0 { - addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) if !ok { return &tcpip.ErrUnknownNICID{} } @@ -304,13 +277,13 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } localAddr = addr.Address - } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { stats.outgoingRequestBadLocalAddressErrors.Increment() return &tcpip.ErrBadLocalAddress{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize, + ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) pkt.NetworkProtocolNumber = ProtocolNumber @@ -318,14 +291,14 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot h.SetOp(header.ARPRequest) // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a // link address. - _ = copy(h.HardwareAddressSender(), nic.LinkAddress()) + _ = copy(h.HardwareAddressSender(), e.nic.LinkAddress()) if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize { panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } - if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { stats.outgoingRequestsDropped.Increment() return err } @@ -334,7 +307,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. -func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { return header.EthernetBroadcastAddress, true } @@ -369,9 +342,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return &protocol{ stack: s, - mu: struct { - sync.RWMutex - eps map[tcpip.NICID]*endpoint - }{eps: make(map[tcpip.NICID]*endpoint)}, } } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index b0f07aa44..d753a97af 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -530,52 +530,19 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { } } -var _ stack.NetworkInterface = (*testInterface)(nil) +var _ stack.LinkEndpoint = (*testLinkEndpoint)(nil) -type testInterface struct { +type testLinkEndpoint struct { stack.LinkEndpoint - nicID tcpip.NICID - writeErr tcpip.Error } -func (t *testInterface) ID() tcpip.NICID { - return t.nicID -} - -func (*testInterface) IsLoopback() bool { - return false -} - -func (*testInterface) Name() string { - return "" -} - -func (*testInterface) Enabled() bool { - return true -} - -func (*testInterface) Promiscuous() bool { - return false -} - -func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) -} - -func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) -} - -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if t.writeErr != nil { return t.writeErr } - var r stack.RouteInfo - r.NetProto = protocol - r.RemoteLinkAddress = remoteLinkAddr return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) } @@ -709,32 +676,31 @@ func TestLinkAddressRequest(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, }) - p := s.NetworkProtocolInstance(arp.ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") - } - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - if err := s.CreateNIC(nicID, linkEP); err != nil { + if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } + ep, err := s.GetNetworkEndpoint(nicID, arp.ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, arp.ProtocolNumber, err) + } + linkRes, ok := ep.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep) + } + if len(test.nicAddr) != 0 { if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) } } - // We pass a test network interface to LinkAddressRequest with the same - // NIC ID and link endpoint used by the NIC we created earlier so that we - // can mock a link address request and observe the packets sent to the - // link endpoint even though the stack uses the real NIC to validate the - // local address. - iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr} - err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface) - if diff := cmp.Diff(test.expectedErr, err); diff != "" { - t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff) + { + err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff) + } } if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent { @@ -782,19 +748,3 @@ func TestLinkAddressRequest(t *testing.T) { }) } } - -func TestLinkAddressRequestWithoutNIC(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - }) - p := s.NetworkProtocolInstance(arp.ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") - } - - err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}) - if _, ok := err.(*tcpip.ErrNotConnected); !ok { - t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, &tcpip.ErrNotConnected{}) - } -} diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go index 036fdf739..d3b56c635 100644 --- a/pkg/tcpip/network/arp/stats_test.go +++ b/pkg/tcpip/network/arp/stats_test.go @@ -34,48 +34,6 @@ func (t *testInterface) ID() tcpip.NICID { return t.nicID } -func knownNICIDs(proto *protocol) []tcpip.NICID { - var nicIDs []tcpip.NICID - - for k := range proto.mu.eps { - nicIDs = append(nicIDs, k) - } - - return nicIDs -} - -func TestClearEndpointFromProtocolOnClose(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - nic := testInterface{nicID: 1} - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) - var nicIDs []tcpip.NICID - - proto.mu.Lock() - foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - - if !hasEndpointBeforeClose { - t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) - } - if foundEP != ep { - t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) - } - - ep.Close() - - proto.mu.Lock() - _, hasEndpointAfterClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if hasEndpointAfterClose { - t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } -} - func TestMultiCounterStatsInitialization(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, |