From ebd7c1b889e5d212f4a694d3addbada241936e8e Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 16 Mar 2021 10:28:04 -0700 Subject: Do not call into Stack from LinkAddressRequest Calling into the stack from LinkAddressRequest is not needed as we already have a reference to the network endpoint (IPv6) or network interface (IPv4/ARP). PiperOrigin-RevId: 363213973 --- pkg/tcpip/network/arp/arp.go | 8 +++---- pkg/tcpip/network/ip_test.go | 4 ++++ pkg/tcpip/network/ipv6/icmp.go | 5 ++++- pkg/tcpip/network/ipv6/icmp_test.go | 4 ++++ pkg/tcpip/stack/nic.go | 14 +++++------- pkg/tcpip/stack/registration.go | 8 +++++++ pkg/tcpip/stack/stack.go | 3 ++- pkg/tcpip/stack/stack_test.go | 43 ++++++++++++++++++++++++++++++++++--- 8 files changed, 70 insertions(+), 19 deletions(-) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index ae0461a6d..43a4b7cac 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -305,8 +305,6 @@ func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { - nicID := e.nic.ID() - stats := e.stats.arp if len(remoteLinkAddr) == 0 { @@ -314,9 +312,9 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } if len(localAddr) == 0 { - addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) - if !ok { - return &tcpip.ErrUnknownNICID{} + addr, err := e.nic.PrimaryAddress(header.IPv4ProtocolNumber) + if err != nil { + return err } if len(addr.Address) == 0 { diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index aee1652fa..a4edc69c7 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -335,6 +335,10 @@ func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tc return nil } +func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { + return tcpip.AddressWithPrefix{}, nil +} + func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { return false } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6344a3e09..8059e0690 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -899,13 +899,16 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } if len(localAddr) == 0 { + // Find an address that we can use as our source address. addressEndpoint := e.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */) if addressEndpoint == nil { return &tcpip.ErrNetworkUnreachable{} } localAddr = addressEndpoint.AddressWithPrefix().Address - } else if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, localAddr) == 0 { + addressEndpoint.DecRef() + } else if !e.checkLocalAddress(localAddr) { + // The provided local address is not assigned to us. return &tcpip.ErrBadLocalAddress{} } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d4e63710c..47d713f88 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -155,6 +155,10 @@ func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, return nil } +func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { + return tcpip.AddressWithPrefix{}, nil +} + func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { return false } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 62f7c880e..ca15c0691 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -568,23 +568,19 @@ func (n *nic) primaryAddresses() []tcpip.ProtocolAddress { return addrs } -// primaryAddress returns the primary address associated with this NIC. -// -// primaryAddress will return the first non-deprecated address if such an -// address exists. If no non-deprecated address exists, the first deprecated -// address will be returned. -func (n *nic) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { +// PrimaryAddress implements NetworkInterface. +func (n *nic) PrimaryAddress(proto tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) { ep, ok := n.networkEndpoints[proto] if !ok { - return tcpip.AddressWithPrefix{} + return tcpip.AddressWithPrefix{}, &tcpip.ErrUnknownProtocol{} } addressableEndpoint, ok := ep.(AddressableEndpoint) if !ok { - return tcpip.AddressWithPrefix{} + return tcpip.AddressWithPrefix{}, &tcpip.ErrNotSupported{} } - return addressableEndpoint.MainAddress() + return addressableEndpoint.MainAddress(), nil } // removeAddress removes an address from n. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 85f0f471a..ff3a385e1 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -525,6 +525,14 @@ type NetworkInterface interface { // assigned to it. Spoofing() bool + // PrimaryAddress returns the primary address associated with the interface. + // + // PrimaryAddress will return the first non-deprecated address if such an + // address exists. If no non-deprecated addresses exist, the first deprecated + // address will be returned. If no deprecated addresses exist, the zero value + // will be returned. + PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) + // CheckLocalAddress returns true if the address exists on the interface. CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 53370c354..1fffe9274 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1224,7 +1224,8 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return tcpip.AddressWithPrefix{}, false } - return nic.primaryAddress(protocol), true + addr, err := nic.PrimaryAddress(protocol) + return addr, err == nil } func (s *Stack) getAddressEP(nic *nic, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 880219007..0d95bc7d6 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1926,6 +1926,39 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { } } +func TestGetMainNICAddressErrors(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + // Sanity check with a successful call. + if addr, ok := s.GetMainNICAddress(nicID, ipv4.ProtocolNumber); !ok { + t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, false), want = (_, true)", nicID, ipv4.ProtocolNumber, addr) + } else if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, _), want = (%s, _)", nicID, ipv4.ProtocolNumber, addr, want) + } + + const unknownNICID = nicID + 1 + if addr, ok := s.GetMainNICAddress(unknownNICID, ipv4.ProtocolNumber); ok { + t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, true), want = (_, false)", unknownNICID, ipv4.ProtocolNumber, addr) + } + + // ARP is not an addressable network endpoint. + if addr, ok := s.GetMainNICAddress(nicID, arp.ProtocolNumber); ok { + t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, true), want = (_, false)", nicID, arp.ProtocolNumber, addr) + } + + const unknownProtocolNumber = 1234 + if addr, ok := s.GetMainNICAddress(nicID, unknownProtocolNumber); ok { + t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, true), want = (_, false)", nicID, unknownProtocolNumber, addr) + } +} + func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, @@ -2507,11 +2540,15 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { } } - // Check that we get no address after removal. - if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, expectedMainAddr); err != nil { t.Fatal(err) } - if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil { + + // Disabling the NIC should remove the auto-generated address. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Fatal(err) } }) -- cgit v1.2.3