summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/network/arp/arp.go8
-rw-r--r--pkg/tcpip/network/ip_test.go4
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go5
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go4
-rw-r--r--pkg/tcpip/stack/nic.go14
-rw-r--r--pkg/tcpip/stack/registration.go8
-rw-r--r--pkg/tcpip/stack/stack.go3
-rw-r--r--pkg/tcpip/stack/stack_test.go43
8 files changed, 70 insertions, 19 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index ae0461a6d..43a4b7cac 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -305,8 +305,6 @@ func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
- nicID := e.nic.ID()
-
stats := e.stats.arp
if len(remoteLinkAddr) == 0 {
@@ -314,9 +312,9 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
if len(localAddr) == 0 {
- addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
- if !ok {
- return &tcpip.ErrUnknownNICID{}
+ addr, err := e.nic.PrimaryAddress(header.IPv4ProtocolNumber)
+ if err != nil {
+ return err
}
if len(addr.Address) == 0 {
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index aee1652fa..a4edc69c7 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -335,6 +335,10 @@ func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tc
return nil
}
+func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
+ return tcpip.AddressWithPrefix{}, nil
+}
+
func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
return false
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 6344a3e09..8059e0690 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -899,13 +899,16 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
if len(localAddr) == 0 {
+ // Find an address that we can use as our source address.
addressEndpoint := e.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */)
if addressEndpoint == nil {
return &tcpip.ErrNetworkUnreachable{}
}
localAddr = addressEndpoint.AddressWithPrefix().Address
- } else if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, localAddr) == 0 {
+ addressEndpoint.DecRef()
+ } else if !e.checkLocalAddress(localAddr) {
+ // The provided local address is not assigned to us.
return &tcpip.ErrBadLocalAddress{}
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index d4e63710c..47d713f88 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -155,6 +155,10 @@ func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber,
return nil
}
+func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
+ return tcpip.AddressWithPrefix{}, nil
+}
+
func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
return false
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 62f7c880e..ca15c0691 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -568,23 +568,19 @@ func (n *nic) primaryAddresses() []tcpip.ProtocolAddress {
return addrs
}
-// primaryAddress returns the primary address associated with this NIC.
-//
-// primaryAddress will return the first non-deprecated address if such an
-// address exists. If no non-deprecated address exists, the first deprecated
-// address will be returned.
-func (n *nic) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
+// PrimaryAddress implements NetworkInterface.
+func (n *nic) PrimaryAddress(proto tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
ep, ok := n.networkEndpoints[proto]
if !ok {
- return tcpip.AddressWithPrefix{}
+ return tcpip.AddressWithPrefix{}, &tcpip.ErrUnknownProtocol{}
}
addressableEndpoint, ok := ep.(AddressableEndpoint)
if !ok {
- return tcpip.AddressWithPrefix{}
+ return tcpip.AddressWithPrefix{}, &tcpip.ErrNotSupported{}
}
- return addressableEndpoint.MainAddress()
+ return addressableEndpoint.MainAddress(), nil
}
// removeAddress removes an address from n.
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 85f0f471a..ff3a385e1 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -525,6 +525,14 @@ type NetworkInterface interface {
// assigned to it.
Spoofing() bool
+ // PrimaryAddress returns the primary address associated with the interface.
+ //
+ // PrimaryAddress will return the first non-deprecated address if such an
+ // address exists. If no non-deprecated addresses exist, the first deprecated
+ // address will be returned. If no deprecated addresses exist, the zero value
+ // will be returned.
+ PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error)
+
// CheckLocalAddress returns true if the address exists on the interface.
CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 53370c354..1fffe9274 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1224,7 +1224,8 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol
return tcpip.AddressWithPrefix{}, false
}
- return nic.primaryAddress(protocol), true
+ addr, err := nic.PrimaryAddress(protocol)
+ return addr, err == nil
}
func (s *Stack) getAddressEP(nic *nic, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint {
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 880219007..0d95bc7d6 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -1926,6 +1926,39 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
}
}
+func TestGetMainNICAddressErrors(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ // Sanity check with a successful call.
+ if addr, ok := s.GetMainNICAddress(nicID, ipv4.ProtocolNumber); !ok {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, false), want = (_, true)", nicID, ipv4.ProtocolNumber, addr)
+ } else if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, _), want = (%s, _)", nicID, ipv4.ProtocolNumber, addr, want)
+ }
+
+ const unknownNICID = nicID + 1
+ if addr, ok := s.GetMainNICAddress(unknownNICID, ipv4.ProtocolNumber); ok {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, true), want = (_, false)", unknownNICID, ipv4.ProtocolNumber, addr)
+ }
+
+ // ARP is not an addressable network endpoint.
+ if addr, ok := s.GetMainNICAddress(nicID, arp.ProtocolNumber); ok {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, true), want = (_, false)", nicID, arp.ProtocolNumber, addr)
+ }
+
+ const unknownProtocolNumber = 1234
+ if addr, ok := s.GetMainNICAddress(nicID, unknownProtocolNumber); ok {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, true), want = (_, false)", nicID, unknownProtocolNumber, addr)
+ }
+}
+
func TestGetMainNICAddressAddRemove(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
@@ -2507,11 +2540,15 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
}
}
- // Check that we get no address after removal.
- if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil {
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, expectedMainAddr); err != nil {
t.Fatal(err)
}
- if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil {
+
+ // Disabling the NIC should remove the auto-generated address.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
t.Fatal(err)
}
})