diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-01-31 10:01:30 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-01-31 10:03:46 -0800 |
commit | daeb06d2cbf5509bd53dc67138504e51d0fcfae8 (patch) | |
tree | aaa3ef4bb641a5aa662342609308e1575777991f | |
parent | 8dda226542d7703ed7cb6df78396d76dff38be45 (diff) |
Hide neighbor table kind from NetworkEndpoint
The network endpoint should not need to have logic to handle different
kinds of neighbor tables. Network endpoints can let the NIC know about
differnt neighbor discovery messages and let the NIC decide which table
to update.
This allows us to remove the LinkAddressCache interface.
PiperOrigin-RevId: 354812584
24 files changed, 342 insertions, 346 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 5c79d6485..5fd4c5574 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -50,10 +50,8 @@ type endpoint struct { // Must be accessed using atomic operations. enabled uint32 - nic stack.NetworkInterface - linkAddrCache stack.LinkAddressCache - nud stack.NUDHandler - stats sharedStats + nic stack.NetworkInterface + stats sharedStats } func (e *endpoint) Enable() tcpip.Error { @@ -150,11 +148,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { remoteAddr := tcpip.Address(h.ProtocolAddressSender()) remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - if e.nud == nil { - e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr) - } else { - e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e) - } + e.nic.HandleNeighborProbe(remoteAddr, remoteLinkAddr, e) respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, @@ -194,14 +188,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - if e.nud == nil { - e.linkAddrCache.AddLinkAddress(addr, linkAddr) - return - } - // The solicited, override, and isRouter flags are not available for ARP; // they are only available for IPv6 Neighbor Advertisements. - e.nud.HandleConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{ + e.nic.HandleNeighborConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{ // Solicited and unsolicited (also referred to as gratuitous) ARP Replies // are handled equivalently to a solicited Neighbor Advertisement. Solicited: true, @@ -234,12 +223,10 @@ func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { return "", "" } -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ - protocol: p, - nic: nic, - linkAddrCache: linkAddrCache, - nud: nud, + protocol: p, + nic: nic, } tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go index d3b56c635..65c708ac4 100644 --- a/pkg/tcpip/network/arp/stats_test.go +++ b/pkg/tcpip/network/arp/stats_test.go @@ -40,7 +40,7 @@ func TestMultiCounterStatsInitialization(t *testing.T) { }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) var nic testInterface - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) // At this point, the Stack's stats and the NetworkEndpoint's stats are // expected to be bound by a MultiCounterStat. refStack := s.Stats() diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 47cce79bb..291330e8e 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -311,6 +311,12 @@ func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.N return &tcpip.ErrNotSupported{} } +func (*testInterface) HandleNeighborProbe(tcpip.Address, tcpip.LinkAddress, stack.LinkAddressResolver) { +} + +func (*testInterface) HandleNeighborConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { +} + func TestSourceAddressValidation(t *testing.T) { rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize @@ -465,7 +471,7 @@ func TestEnableWhenNICDisabled(t *testing.T) { // We pass nil for all parameters except the NetworkInterface and Stack // since Enable only depends on these. - ep := p.NewEndpoint(&nic, nil, nil, nil) + ep := p.NewEndpoint(&nic, nil) // The endpoint should initially be disabled, regardless the NIC's enabled // status. @@ -527,7 +533,7 @@ func TestIPv4Send(t *testing.T) { v4: true, }, } - ep := proto.NewEndpoint(&nic, nil, nil, nil) + ep := proto.NewEndpoint(&nic, nil) defer ep.Close() // Allocate and initialize the payload view. @@ -661,7 +667,7 @@ func TestReceive(t *testing.T) { v4: test.v4, }, } - ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, nil, nil, &nic.testObject) + ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -722,7 +728,7 @@ func TestIPv4ReceiveControl(t *testing.T) { t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) + ep := proto.NewEndpoint(&nic, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -811,7 +817,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { v4: true, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) + ep := proto.NewEndpoint(&nic, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -906,7 +912,7 @@ func TestIPv6Send(t *testing.T) { t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, nil) + ep := proto.NewEndpoint(&nic, nil) defer ep.Close() if err := ep.Enable(); err != nil { @@ -979,7 +985,7 @@ func TestIPv6ReceiveControl(t *testing.T) { t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) + ep := proto.NewEndpoint(&nic, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 04c6a6708..e146844c2 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -105,7 +105,7 @@ func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, dispatcher: dispatcher, diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go index b28e7dcde..fbbc6e69c 100644 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -50,7 +50,7 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) { }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) nic := testInterface{nicID: 1} - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) var nicIDs []tcpip.NICID proto.mu.Lock() @@ -82,7 +82,7 @@ func TestMultiCounterStatsInitialization(t *testing.T) { }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) var nic testInterface - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) // At this point, the Stack's stats and the NetworkEndpoint's stats are // expected to be bound by a MultiCounterStat. refStack := s.Stats() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 8db2454d3..bdc88fe5d 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -289,10 +289,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } else if unspecifiedSource { received.invalid.Increment() return - } else if e.nud != nil { - e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e) } else { - e.linkAddrCache.AddLinkAddress(srcAddr, sourceLinkAddr) + e.nic.HandleNeighborProbe(srcAddr, sourceLinkAddr, e) } // As per RFC 4861 section 7.1.1: @@ -458,14 +456,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // If the NA message has the target link layer option, update the link // address cache with the link address for the target of the message. - if e.nud == nil { - if len(targetLinkAddr) != 0 { - e.linkAddrCache.AddLinkAddress(targetAddr, targetLinkAddr) - } - return - } - - e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ + e.nic.HandleNeighborConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ Solicited: na.SolicitedFlag(), Override: na.OverrideFlag(), IsRouter: na.RouterFlag(), @@ -575,11 +566,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { return } - if e.nud != nil { - // A RS with a specified source IP address modifies the NUD state - // machine in the same way a reachability probe would. - e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e) - } + // A RS with a specified source IP address modifies the NUD state + // machine in the same way a reachability probe would. + e.nic.HandleNeighborProbe(srcAddr, sourceLinkAddr, e) } case header.ICMPv6RouterAdvert: @@ -627,8 +616,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // If the RA has the source link layer option, update the link address // cache with the link address for the advertised router. - if len(sourceLinkAddr) != 0 && e.nud != nil { - e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e) + if len(sourceLinkAddr) != 0 { + e.nic.HandleNeighborProbe(routerAddr, sourceLinkAddr, e) } e.mu.Lock() diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index a5c88444e..755293377 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -93,35 +93,14 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st return stack.TransportPacketHandled } -var _ stack.LinkAddressCache = (*stubLinkAddressCache)(nil) - -type stubLinkAddressCache struct{} - -func (*stubLinkAddressCache) AddLinkAddress(tcpip.Address, tcpip.LinkAddress) {} - -type stubNUDHandler struct { - probeCount int - confirmationCount int -} - -var _ stack.NUDHandler = (*stubNUDHandler)(nil) - -func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) { - s.probeCount++ -} - -func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { - s.confirmationCount++ -} - -func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) { -} - var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { stack.LinkEndpoint + probeCount int + confirmationCount int + nicID tcpip.NICID } @@ -160,6 +139,14 @@ func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gs return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) } +func (t *testInterface) HandleNeighborProbe(tcpip.Address, tcpip.LinkAddress, stack.LinkAddressResolver) { + t.probeCount++ +} + +func (t *testInterface) HandleNeighborConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { + t.confirmationCount++ +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -202,7 +189,7 @@ func TestICMPCounts(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { @@ -360,7 +347,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { @@ -1801,8 +1788,9 @@ func TestCallsToNeighborCache(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - nudHandler := &stubNUDHandler{} - ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) + + testInterface := testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)} + ep := netProto.NewEndpoint(&testInterface, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { @@ -1837,11 +1825,11 @@ func TestCallsToNeighborCache(t *testing.T) { ep.HandlePacket(pkt) // Confirm the endpoint calls the correct NUDHandler method. - if nudHandler.probeCount != test.wantProbeCount { - t.Errorf("got nudHandler.probeCount = %d, want = %d", nudHandler.probeCount, test.wantProbeCount) + if testInterface.probeCount != test.wantProbeCount { + t.Errorf("got testInterface.probeCount = %d, want = %d", testInterface.probeCount, test.wantProbeCount) } - if nudHandler.confirmationCount != test.wantConfirmationCount { - t.Errorf("got nudHandler.confirmationCount = %d, want = %d", nudHandler.confirmationCount, test.wantConfirmationCount) + if testInterface.confirmationCount != test.wantConfirmationCount { + t.Errorf("got testInterface.confirmationCount = %d, want = %d", testInterface.confirmationCount, test.wantConfirmationCount) } }) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index b55a35525..e56eb5796 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -173,13 +173,11 @@ var _ stack.NDPEndpoint = (*endpoint)(nil) var _ NDPEndpoint = (*endpoint)(nil) type endpoint struct { - nic stack.NetworkInterface - linkAddrCache stack.LinkAddressCache - nud stack.NUDHandler - dispatcher stack.TransportDispatcher - protocol *protocol - stack *stack.Stack - stats sharedStats + nic stack.NetworkInterface + dispatcher stack.TransportDispatcher + protocol *protocol + stack *stack.Stack + stats sharedStats // enabled is set to 1 when the endpoint is enabled and 0 when it is // disabled. @@ -1733,13 +1731,11 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ - nic: nic, - linkAddrCache: linkAddrCache, - nud: nud, - dispatcher: dispatcher, - protocol: p, + nic: nic, + dispatcher: dispatcher, + protocol: p, } e.mu.Lock() e.mu.addressableEndpointState.Init(e) diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 8248052a3..1c6c37c91 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -2599,7 +2599,7 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) { }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) var nic testInterface - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) var nicIDs []tcpip.NICID proto.mu.Lock() @@ -3075,7 +3075,7 @@ func TestMultiCounterStatsInitialization(t *testing.T) { }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) var nic testInterface - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) // At this point, the Stack's stats and the NetworkEndpoint's stats are // supposed to be bound. refStack := s.Stats() diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 1d38b8b05..4cc81e6cc 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -63,7 +63,7 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) if err := ep.Enable(); err != nil { t.Fatalf("ep.Enable(): %s", err) } @@ -199,6 +199,7 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -760,6 +761,7 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 1e4ddf0d5..704812641 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -161,10 +161,9 @@ var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) type fwdTestNetworkProtocol struct { stack *Stack - addrCache *linkAddrCache - neigh *neighborCache + neighborTable neighborTable addrResolveDelay time.Duration - onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) + onLinkAddressResolved func(neighborTable, tcpip.Address, tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) mu struct { @@ -197,7 +196,7 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } -func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { +func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, dispatcher TransportDispatcher) NetworkEndpoint { e := &fwdTestNetworkEndpoint{ nic: nic, proto: f, @@ -222,7 +221,7 @@ func (*fwdTestNetworkProtocol) Wait() {} func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { if fn := f.proto.onLinkAddressResolved; fn != nil { time.AfterFunc(f.proto.addrResolveDelay, func() { - fn(f.proto.addrCache, f.proto.neigh, addr, remoteLinkAddr) + fn(f.proto.neighborTable, addr, remoteLinkAddr) }) } return nil @@ -401,12 +400,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC if !ok { t.Fatal("NIC 2 does not exist") } - if useNeighborCache { - // Control the neighbor cache for NIC 2. - proto.neigh = nic.neigh - } else { - proto.addrCache = nic.linkAddrCache - } + proto.neighborTable = nic.neighborTable // Route all packets to NIC 2. { @@ -482,43 +476,35 @@ func TestForwardingWithFakeResolver(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any address will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any address will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -573,7 +559,7 @@ func TestForwardingWithNoResolver(t *testing.T) { func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { proto := &fwdTestNetworkProtocol{ addrResolveDelay: 50 * time.Millisecond, - onLinkAddressResolved: func(*linkAddrCache, *neighborCache, tcpip.Address, tcpip.LinkAddress) { + onLinkAddressResolved: func(neighborTable, tcpip.Address, tcpip.LinkAddress) { // Don't resolve the link address. }, } @@ -606,49 +592,38 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - cache.AddLinkAddress(addr, "c") - } - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Only packets to address 3 will be resolved to the // link address "c". if addr == "\x03" { - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) } }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. @@ -693,43 +668,35 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) // Inject two inbound packets to address 3 on NIC 1. for i := 0; i < 2; i++ { @@ -769,43 +736,35 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. @@ -864,38 +823,31 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) for i := 0; i < maxPendingResolutions+5; i++ { // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index cd2bb3417..4504db752 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -24,8 +24,6 @@ import ( const linkAddrCacheSize = 512 // max cache entries -var _ LinkAddressCache = (*linkAddrCache)(nil) - // linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses. // // The entries are stored in a ring buffer, oldest entry replaced first. @@ -140,7 +138,7 @@ func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { } // add adds a k -> v mapping to the cache. -func (c *linkAddrCache) AddLinkAddress(k tcpip.Address, v tcpip.LinkAddress) { +func (c *linkAddrCache) add(k tcpip.Address, v tcpip.LinkAddress) { // Calculate expiration time before acquiring the lock, since expiration is // relative to the time when information was learned, rather than when it // happened to be inserted into the cache. @@ -290,3 +288,67 @@ func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resol c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) return c } + +var _ neighborTable = (*linkAddrCache)(nil) + +func (*linkAddrCache) neighbors() ([]NeighborEntry, tcpip.Error) { + return nil, &tcpip.ErrNotSupported{} +} + +func (c *linkAddrCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) { + c.add(addr, linkAddr) +} + +func (*linkAddrCache) remove(addr tcpip.Address) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +func (*linkAddrCache) removeAll() tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +func (c *linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress, _ LinkAddressResolver) { + if len(linkAddr) != 0 { + // NUD allows probes without a link address but linkAddrCache + // is a simple neighbor table which does not implement NUD. + // + // As per RFC 4861 section 4.3, + // + // Source link-layer address + // The link-layer address for the sender. MUST NOT be + // included when the source IP address is the + // unspecified address. Otherwise, on link layers + // that have addresses this option MUST be included in + // multicast solicitations and SHOULD be included in + // unicast solicitations. + c.add(addr, linkAddr) + } +} + +func (c *linkAddrCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + if len(linkAddr) != 0 { + // NUD allows confirmations without a link address but linkAddrCache + // is a simple neighbor table which does not implement NUD. + // + // As per RFC 4861 section 4.4, + // + // Target link-layer address + // The link-layer address for the target, i.e., the + // sender of the advertisement. This option MUST be + // included on link layers that have addresses when + // responding to multicast solicitations. When + // responding to a unicast Neighbor Solicitation this + // option SHOULD be included. + c.add(addr, linkAddr) + } +} + +func (c *linkAddrCache) handleUpperLevelConfirmation(tcpip.Address) {} + +func (*linkAddrCache) nudConfig() (NUDConfigurations, tcpip.Error) { + return NUDConfigurations{}, &tcpip.ErrNotSupported{} +} + +func (*linkAddrCache) setNUDConfig(NUDConfigurations) tcpip.Error { + return &tcpip.ErrNotSupported{} +} diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 40017c8b6..4df6f9265 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -60,7 +60,7 @@ func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { for _, ta := range testAddrs { if ta.addr == addr { - r.cache.AddLinkAddress(ta.addr, ta.linkAddr) + r.cache.add(ta.addr, ta.linkAddr) break } } @@ -103,7 +103,7 @@ func TestCacheOverflow(t *testing.T) { c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) for i := len(testAddrs) - 1; i >= 0; i-- { e := testAddrs[i] - c.AddLinkAddress(e.addr, e.linkAddr) + c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil) if err != nil { t.Errorf("insert %d, c.get(%s, nil, '', nil): %s", i, e.addr, err) @@ -143,7 +143,7 @@ func TestCacheConcurrent(t *testing.T) { wg.Add(1) go func() { for _, e := range testAddrs { - c.AddLinkAddress(e.addr, e.linkAddr) + c.add(e.addr, e.linkAddr) } wg.Done() }() @@ -175,7 +175,7 @@ func TestCacheAgeLimit(t *testing.T) { linkRes := &testLinkAddressResolver{cache: c} e := testAddrs[0] - c.AddLinkAddress(e.addr, e.linkAddr) + c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) _, _, err := c.get(e.addr, linkRes, "", nil) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { @@ -187,7 +187,7 @@ func TestCacheReplace(t *testing.T) { c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) e := testAddrs[0] l2 := e.linkAddr + "2" - c.AddLinkAddress(e.addr, e.linkAddr) + c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil) if err != nil { t.Errorf("c.get(%s, nil, '', nil): %s", e.addr, err) @@ -196,7 +196,7 @@ func TestCacheReplace(t *testing.T) { t.Errorf("got c.get(%s, nil, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } - c.AddLinkAddress(e.addr, l2) + c.add(e.addr, l2) got, _, err = c.get(e.addr, nil, "", nil) if err != nil { t.Errorf("c.get(%s, nil, '', nil): %s", e.addr, err) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 64383bc7c..c13be137e 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -2796,14 +2796,8 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN NIC: nicID, }}) - if useNeighborCache { - if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) - } - } else { - if err := s.AddLinkAddress(nicID, llAddr3, linkAddr3); err != nil { - t.Fatalf("s.AddLinkAddress(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) - } + if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) } return ndpDisp, e, s } diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 88a3ff776..64b8046f5 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -42,8 +42,6 @@ type NeighborStats struct { // 2. Static entries are explicitly added by a user and have no expiration. // Their state is always Static. The amount of static entries stored in the // cache is unbounded. -// -// neighborCache implements NUDHandler. type neighborCache struct { nic *NIC state *NUDState @@ -62,8 +60,6 @@ type neighborCache struct { } } -var _ NUDHandler = (*neighborCache)(nil) - // getOrCreateEntry retrieves a cache entry associated with addr. The // returned entry is always refreshed in the cache (it is reachable via the // map, and its place is bumped in LRU). @@ -263,27 +259,45 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { n.state.SetConfig(config) } -// HandleProbe implements NUDHandler.HandleProbe by following the logic defined -// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled -// by the caller. -func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { +var _ neighborTable = (*neighborCache)(nil) + +func (n *neighborCache) neighbors() ([]NeighborEntry, tcpip.Error) { + return n.entries(), nil +} + +func (n *neighborCache) get(addr tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + entry, ch, err := n.entry(addr, localAddr, linkRes, onResolve) + return entry.LinkAddr, ch, err +} + +func (n *neighborCache) remove(addr tcpip.Address) tcpip.Error { + if !n.removeEntry(addr) { + return &tcpip.ErrBadAddress{} + } + + return nil +} + +func (n *neighborCache) removeAll() tcpip.Error { + n.clear() + return nil +} + +// handleProbe handles a neighbor probe as defined by RFC 4861 section 7.2.3. +// +// Validation of the probe is expected to be handled by the caller. +func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() } -// HandleConfirmation implements NUDHandler.HandleConfirmation by following the -// logic defined in RFC 4861 section 7.2.5. +// handleConfirmation handles a neighbor confirmation as defined by +// RFC 4861 section 7.2.5. // -// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other -// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol -// should be deployed where preventing access to the broadcast segment might -// not be possible. SEND uses RSA key pairs to produce cryptographically -// generated addresses, as defined in RFC 3972, Cryptographically Generated -// Addresses (CGA). This ensures that the claimed source of an NDP message is -// the owner of the claimed address. -func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { +// Validation of the confirmation is expected to be handled by the caller. +func (n *neighborCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { n.mu.RLock() entry, ok := n.cache[addr] n.mu.RUnlock() @@ -309,3 +323,12 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { entry.mu.Unlock() } } + +func (n *neighborCache) nudConfig() (NUDConfigurations, tcpip.Error) { + return n.config(), nil +} + +func (n *neighborCache) setNUDConfig(c NUDConfigurations) tcpip.Error { + n.setConfig(c) + return nil +} diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 0f7925774..122888fcf 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -91,7 +91,6 @@ func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock state: NewNUDState(config, rng), cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } - neigh.nic.neigh = neigh return neigh } @@ -212,7 +211,7 @@ func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ // fakeRequest emulates handling a response for a link address request. func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) { if entry, ok := r.entries.entryByAddr(addr); ok { - r.neigh.HandleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ + r.neigh.handleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, @@ -1473,7 +1472,7 @@ func TestNeighborCacheReplace(t *testing.T) { updatedLinkAddr = entry.LinkAddr } store.set(0, updatedLinkAddr) - neigh.HandleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ + neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index c5c3d266b..5e5e0e6ca 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -231,7 +231,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e stats: makeNICStats(), } nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ - header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil), + header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil), } rng := rand.New(rand.NewSource(time.Now().UnixNano())) @@ -240,12 +240,13 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes) // Stub out the neighbor cache to verify deletion from the cache. - nic.neigh = &neighborCache{ + neigh := &neighborCache{ nic: &nic, state: nudState, cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } - nic.neigh.cache[entryTestAddr1] = entry + neigh.cache[entryTestAddr1] = entry + nic.neighborTable = neigh return entry, &disp, &linkRes, clock } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0707c3ce2..c813b0da5 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "math/rand" "reflect" "sync/atomic" @@ -25,6 +24,21 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) +type neighborTable interface { + neighbors() ([]NeighborEntry, tcpip.Error) + addStaticEntry(tcpip.Address, tcpip.LinkAddress) + get(addr tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) + remove(tcpip.Address) tcpip.Error + removeAll() tcpip.Error + + handleProbe(tcpip.Address, tcpip.LinkAddress, LinkAddressResolver) + handleConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) + handleUpperLevelConfirmation(tcpip.Address) + + nudConfig() (NUDConfigurations, tcpip.Error) + setNUDConfig(NUDConfigurations) tcpip.Error +} + var _ NetworkInterface = (*NIC)(nil) // NIC represents a "network interface card" to which the networking stack is @@ -38,7 +52,6 @@ type NIC struct { context NICContext stats NICStats - neigh *neighborCache // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. @@ -54,7 +67,7 @@ type NIC struct { // complete. linkResQueue packetsPendingLinkResolution - linkAddrCache *linkAddrCache + neighborTable neighborTable mu struct { sync.RWMutex @@ -143,26 +156,20 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), } nic.linkResQueue.init(nic) - nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) - // Check for Neighbor Unreachability Detection support. - var nud NUDHandler - if ep.Capabilities()&CapabilityResolutionRequired != 0 && stack.useNeighborCache { - rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds())) - nic.neigh = &neighborCache{ - nic: nic, - state: NewNUDState(stack.nudConfigs, rng), - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), - } + resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0 - // An interface value that holds a nil pointer but non-nil type is not the - // same as the nil interface. Because of this, nud must only be assignd if - // nic.neigh is non-nil since a nil reference to a neighborCache is not - // valid. - // - // See https://golang.org/doc/faq#nil_error for more information. - nud = nic.neigh + if resolutionRequired { + if stack.useNeighborCache { + nic.neighborTable = &neighborCache{ + nic: nic, + state: NewNUDState(stack.nudConfigs, stack.randomGenerator), + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + } else { + nic.neighborTable = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts) + } } // Register supported packet and network endpoint protocols. @@ -173,11 +180,13 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC netNum := netProto.Number() nic.mu.packetEPs[netNum] = new(packetEndpointList) - netEP := netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic) + netEP := netProto.NewEndpoint(nic, nic) nic.networkEndpoints[netNum] = netEP - if r, ok := netEP.(LinkAddressResolver); ok { - nic.linkAddrResolvers[r.LinkAddressProtocol()] = r + if resolutionRequired { + if r, ok := netEP.(LinkAddressResolver); ok { + nic.linkAddrResolvers[r.LinkAddressProtocol()] = r + } } } @@ -596,8 +605,8 @@ func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error { } func (n *NIC) confirmReachable(addr tcpip.Address) { - if n := n.neigh; n != nil { - n.handleUpperLevelConfirmation(addr) + if n.neighborTable != nil { + n.neighborTable.handleUpperLevelConfirmation(addr) } } @@ -617,49 +626,44 @@ func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.Netwo } func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - if n.neigh != nil { - entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve) - return entry.LinkAddr, ch, err + if n.neighborTable != nil { + return n.neighborTable.get(addr, linkRes, localAddr, onResolve) } - return n.linkAddrCache.get(addr, linkRes, localAddr, onResolve) + return "", nil, &tcpip.ErrNotSupported{} } func (n *NIC) neighbors() ([]NeighborEntry, tcpip.Error) { - if n.neigh == nil { - return nil, &tcpip.ErrNotSupported{} + if n.neighborTable != nil { + return n.neighborTable.neighbors() } - return n.neigh.entries(), nil + return nil, &tcpip.ErrNotSupported{} } func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) tcpip.Error { - if n.neigh == nil { - return &tcpip.ErrNotSupported{} + if n.neighborTable != nil { + n.neighborTable.addStaticEntry(addr, linkAddress) + return nil } - n.neigh.addStaticEntry(addr, linkAddress) - return nil + return &tcpip.ErrNotSupported{} } func (n *NIC) removeNeighbor(addr tcpip.Address) tcpip.Error { - if n.neigh == nil { - return &tcpip.ErrNotSupported{} + if n.neighborTable != nil { + return n.neighborTable.remove(addr) } - if !n.neigh.removeEntry(addr) { - return &tcpip.ErrBadAddress{} - } - return nil + return &tcpip.ErrNotSupported{} } func (n *NIC) clearNeighbors() tcpip.Error { - if n.neigh == nil { - return &tcpip.ErrNotSupported{} + if n.neighborTable != nil { + return n.neighborTable.removeAll() } - n.neigh.clear() - return nil + return &tcpip.ErrNotSupported{} } // joinGroup adds a new endpoint for the given multicast address, if none @@ -944,10 +948,11 @@ func (n *NIC) Name() string { // nudConfigs gets the NUD configurations for n. func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) { - if n.neigh == nil { - return NUDConfigurations{}, &tcpip.ErrNotSupported{} + if n.neighborTable != nil { + return n.neighborTable.nudConfig() } - return n.neigh.config(), nil + + return NUDConfigurations{}, &tcpip.ErrNotSupported{} } // setNUDConfigs sets the NUD configurations for n. @@ -955,12 +960,12 @@ func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) { // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. func (n *NIC) setNUDConfigs(c NUDConfigurations) tcpip.Error { - if n.neigh == nil { - return &tcpip.ErrNotSupported{} + if n.neighborTable != nil { + c.resetInvalidFields() + return n.neighborTable.setNUDConfig(c) } - c.resetInvalidFields() - n.neigh.setConfig(c) - return nil + + return &tcpip.ErrNotSupported{} } func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { @@ -996,3 +1001,17 @@ func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { n.mu.RUnlock() return n.Enabled() && ep.IsAssigned(spoofing) } + +// HandleNeighborProbe implements NetworkInterface. +func (n *NIC) HandleNeighborProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { + if n.neighborTable != nil { + n.neighborTable.handleProbe(addr, linkAddr, linkRes) + } +} + +// HandleNeighborConfirmation implements NetworkInterface. +func (n *NIC) HandleNeighborConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + if n.neighborTable != nil { + n.neighborTable.handleConfirmation(addr, linkAddr, flags) + } +} diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 3564202d8..9992d6eb4 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -137,7 +137,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) } // NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint { +func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ TransportDispatcher) NetworkEndpoint { e := &testIPv6Endpoint{ nic: nic, protocol: p, diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 77926e289..5a94e9ac6 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -161,21 +161,6 @@ type ReachabilityConfirmationFlags struct { IsRouter bool } -// NUDHandler communicates external events to the Neighbor Unreachability -// Detection state machine, which is implemented per-interface. This is used by -// network endpoints to inform the Neighbor Cache of probes and confirmations. -type NUDHandler interface { - // HandleProbe processes an incoming neighbor probe (e.g. ARP request or - // Neighbor Solicitation for ARP or NDP, respectively). Validation of the - // probe needs to be performed before calling this function since the - // Neighbor Cache doesn't have access to view the NIC's assigned addresses. - HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) - - // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP - // reply or Neighbor Advertisement for ARP or NDP, respectively). - HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) -} - // NUDConfigurations is the NUD configurations for the netstack. This is used // by the neighbor cache to operate the NUD state machine on each device in the // local network. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 64b5627e1..c652c2bd7 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -530,6 +530,17 @@ type NetworkInterface interface { // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) + + // HandleNeighborProbe processes an incoming neighbor probe (e.g. ARP + // request or NDP Neighbor Solicitation). + // + // HandleNeighborProbe assumes that the probe is valid for the network + // interface the probe was received on. + HandleNeighborProbe(tcpip.Address, tcpip.LinkAddress, LinkAddressResolver) + + // HandleNeighborConfirmation processes an incoming neighbor confirmation + // (e.g. ARP reply or NDP Neighbor Advertisement). + HandleNeighborConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) } // LinkResolvableNetworkEndpoint handles link resolution events. @@ -649,7 +660,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint + NewEndpoint(nic NetworkInterface, dispatcher TransportDispatcher) NetworkEndpoint // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -843,12 +854,6 @@ type LinkAddressResolver interface { LinkAddressProtocol() tcpip.NetworkProtocolNumber } -// A LinkAddressCache caches link addresses. -type LinkAddressCache interface { - // AddLinkAddress adds a link address to the cache. - AddLinkAddress(addr tcpip.Address, linkAddr tcpip.LinkAddress) -} - // RawFactory produces endpoints for writing various types of raw packets. type RawFactory interface { // NewUnassociatedEndpoint produces endpoints for writing packets not diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 66bf22823..73db6e031 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1526,20 +1526,6 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) tcpip.Error { return nil } -// AddLinkAddress adds a link address for the neighbor on the specified NIC. -func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { - s.mu.RLock() - defer s.mu.RUnlock() - - nic, ok := s.nics[nicID] - if !ok { - return &tcpip.ErrUnknownNICID{} - } - - nic.linkAddrCache.AddLinkAddress(neighbor, linkAddr) - return nil -} - // LinkResolutionResult is the result of a link address resolution attempt. type LinkResolutionResult struct { LinkAddress tcpip.LinkAddress diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 41f95811f..a166c0502 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -243,7 +243,7 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &fakeNetworkEndpoint{ nic: nic, proto: f, @@ -4391,7 +4391,9 @@ func TestStaticGetLinkAddress(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, }) - if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil { + e := channel.New(0, 0, "") + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 64c5298d3..5d81dbb94 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1614,7 +1614,7 @@ func TestTTL(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{p}, }) - ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil, nil, nil) + ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil) wantTTL = ep.DefaultTTL() ep.Close() } |