summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go13
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go21
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go2
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go111
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go11
-rw-r--r--pkg/tcpip/stack/nic.go22
-rw-r--r--pkg/tcpip/stack/registration.go13
-rw-r--r--pkg/tcpip/stack/stack_test.go63
8 files changed, 185 insertions, 71 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 746cf520d..ad7a767a4 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -595,6 +595,13 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
return e.mu.addressableEndpointState.RemovePermanentAddress(addr)
}
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
// AcquireAssignedAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
e.mu.Lock()
@@ -625,11 +632,11 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo
return addressEndpoint
}
-// AcquirePrimaryAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AcquirePrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.AcquirePrimaryAddress(remoteAddr, allowExpired)
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
}
// PrimaryAddresses implements stack.AddressableEndpoint.
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 3705f56a2..aff4e1425 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -922,6 +922,13 @@ func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpo
return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr)
}
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
// AcquireAssignedAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
e.mu.Lock()
@@ -937,18 +944,18 @@ func (e *endpoint) acquireAddressOrCreateTempLocked(localAddr tcpip.Address, all
return e.mu.addressableEndpointState.AcquireAssignedAddress(localAddr, allowTemp, tempPEB)
}
-// AcquirePrimaryAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AcquirePrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.acquirePrimaryAddressRLocked(remoteAddr, allowExpired)
+ return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
}
-// acquirePrimaryAddressRLocked is like AcquirePrimaryAddress but with locking
-// requirements.
+// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
+// but with locking requirements.
//
// Precondition: e.mu must be read locked.
-func (e *endpoint) acquirePrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
// addrCandidate is a candidate for Source Address Selection, as per
// RFC 6724 section 5.
type addrCandidate struct {
@@ -957,7 +964,7 @@ func (e *endpoint) acquirePrimaryAddressRLocked(remoteAddr tcpip.Address, allowE
}
if len(remoteAddr) == 0 {
- return e.mu.addressableEndpointState.AcquirePrimaryAddress(remoteAddr, allowExpired)
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
}
// Create a candidate set of available addresses we can potentially use as a
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 84c082852..48a4c65e3 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -1891,7 +1891,7 @@ func (ndp *ndpState) startSolicitingRouters() {
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- addressEndpoint := ndp.ep.acquirePrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
+ addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
if addressEndpoint == nil {
// Incase this ends up creating a new temporary address, we need to hold
// onto the endpoint until a route is obtained. If we decrement the
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 270ac4977..db8ac1c2b 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -412,6 +412,60 @@ func (a *AddressableEndpointState) decAddressRefLocked(addrState *addressState)
a.releaseAddressStateLocked(addrState)
}
+// MainAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
+ return ep.GetKind() == Permanent
+ })
+ if ep == nil {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ addr := ep.AddressWithPrefix()
+ a.decAddressRefLocked(ep)
+ return addr
+}
+
+// acquirePrimaryAddressRLocked returns an acquired primary address that is
+// valid according to isValid.
+//
+// Precondition: e.mu must be read locked
+func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*addressState) bool) *addressState {
+ var deprecatedEndpoint *addressState
+ for _, ep := range a.mu.primary {
+ if !isValid(ep) {
+ continue
+ }
+
+ if !ep.Deprecated() {
+ if ep.IncRef() {
+ // ep is not deprecated, so return it immediately.
+ //
+ // If we kept track of a deprecated endpoint, decrement its reference
+ // count since it was incremented when we decided to keep track of it.
+ if deprecatedEndpoint != nil {
+ a.decAddressRefLocked(deprecatedEndpoint)
+ deprecatedEndpoint = nil
+ }
+
+ return ep
+ }
+ } else if deprecatedEndpoint == nil && ep.IncRef() {
+ // We prefer an endpoint that is not deprecated, but we keep track of
+ // ep in case a doesn't have any non-deprecated endpoints.
+ //
+ // If we end up finding a more preferred endpoint, ep's reference count
+ // will be decremented.
+ deprecatedEndpoint = ep
+ }
+ }
+
+ return deprecatedEndpoint
+}
+
// AcquireAssignedAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
a.mu.Lock()
@@ -461,47 +515,34 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return ep
}
-// AcquirePrimaryAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AcquirePrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
+// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
a.mu.RLock()
defer a.mu.RUnlock()
- var deprecatedEndpoint *addressState
- for _, ep := range a.mu.primary {
- if !ep.IsAssigned(allowExpired) {
- continue
- }
-
- if !ep.Deprecated() {
- if ep.IncRef() {
- // ep is not deprecated, so return it immediately.
- //
- // If we kept track of a deprecated endpoint, decrement its reference
- // count since it was incremented when we decided to keep track of it.
- if deprecatedEndpoint != nil {
- a.decAddressRefLocked(deprecatedEndpoint)
- deprecatedEndpoint = nil
- }
-
- return ep
- }
- } else if deprecatedEndpoint == nil && ep.IncRef() {
- // We prefer an endpoint that is not deprecated, but we keep track of
- // ep in case a doesn't have any non-deprecated endpoints.
- //
- // If we end up finding a more preferred endpoint, ep's reference count
- // will be decremented.
- deprecatedEndpoint = ep
- }
- }
+ ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
+ return ep.IsAssigned(allowExpired)
+ })
- // a doesn't have any valid non-deprecated endpoints, so return
- // deprecatedEndpoint (which may be nil if a doesn't have any valid deprecated
- // endpoints either).
- if deprecatedEndpoint == nil {
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since acquirePrimaryAddressRLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
return nil
}
- return deprecatedEndpoint
+
+ return ep
}
// PrimaryAddresses implements AddressableEndpoint.
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index de4e0d7b1..26787d0a3 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -24,8 +24,13 @@ import (
// TestAddressableEndpointStateCleanup tests that cleaning up an addressable
// endpoint state removes permanent addresses and leaves groups.
func TestAddressableEndpointStateCleanup(t *testing.T) {
+ var ep fakeNetworkEndpoint
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
var s stack.AddressableEndpointState
- s.Init(&fakeNetworkEndpoint{})
+ s.Init(&ep)
addr := tcpip.AddressWithPrefix{
Address: "\x01",
@@ -43,7 +48,7 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
{
ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
if ep == nil {
- t.Fatalf("got s.AcquireAssignedAddress(%s) = nil, want = non-nil", addr.Address)
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address)
}
ep.DecRef()
}
@@ -63,7 +68,7 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
if ep != nil {
ep.DecRef()
- t.Fatalf("got s.AcquireAssignedAddress(%s) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
}
}
if s.IsInGroup(group) {
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 212c6edae..23022292c 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -244,22 +244,19 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
-// primaryEndpoint will return the first non-deprecated endpoint if such an
-// endpoint exists for the given protocol and remoteAddr. If no non-deprecated
-// endpoint exists, the first deprecated endpoint will be returned.
-//
-// If an IPv6 primary endpoint is requested, Source Address Selection (as
-// defined by RFC 6724 section 5) will be performed.
+// primaryAddress returns an address that can be used to communicate with
+// remoteAddr.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint {
n.mu.RLock()
- defer n.mu.RUnlock()
+ spoofing := n.mu.spoofing
+ n.mu.RUnlock()
ep, ok := n.networkEndpoints[protocol]
if !ok {
return nil
}
- return ep.AcquirePrimaryAddress(remoteAddr, n.mu.spoofing)
+ return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
}
type getAddressBehaviour int
@@ -360,13 +357,12 @@ func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress {
// address exists. If no non-deprecated address exists, the first deprecated
// address will be returned.
func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
- addressEndpoint := n.primaryEndpoint(proto, "")
- if addressEndpoint == nil {
+ ep, ok := n.networkEndpoints[proto]
+ if !ok {
return tcpip.AddressWithPrefix{}
}
- addr := addressEndpoint.AddressWithPrefix()
- addressEndpoint.DecRef()
- return addr
+
+ return ep.MainAddress()
}
// removeAddress removes an address from n.
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 567e1904e..b6f823b54 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -435,7 +435,10 @@ type AddressableEndpoint interface {
// permanent address.
RemovePermanentAddress(addr tcpip.Address) *tcpip.Error
- // AcquireAssignedAddress returns an AddressEndpoint for the passed address
+ // MainAddress returns the endpoint's primary permanent address.
+ MainAddress() tcpip.AddressWithPrefix
+
+ // AcquireAssignedAddress returns an address endpoint for the passed address
// that is considered bound to the endpoint, optionally creating a temporary
// endpoint if requested and no existing address exists.
//
@@ -444,15 +447,15 @@ type AddressableEndpoint interface {
// Returns nil if the specified address is not local to this endpoint.
AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint
- // AcquirePrimaryAddress returns a primary endpoint to use when communicating
- // with the passed remote address.
+ // AcquireOutgoingPrimaryAddress returns a primary address that may be used as
+ // a source address when sending packets to the passed remote address.
//
// If allowExpired is true, expired addresses may be returned.
//
// The returned endpoint's reference count is incremented.
//
- // Returns nil if a primary endpoint is not available.
- AcquirePrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint
+ // Returns nil if a primary address is not available.
+ AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint
// PrimaryAddresses returns the primary addresses.
PrimaryAddresses() []tcpip.AddressWithPrefix
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index fda22c550..aa20f750b 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -71,21 +71,36 @@ const (
type fakeNetworkEndpoint struct {
stack.AddressableEndpointState
+ mu struct {
+ sync.RWMutex
+
+ enabled bool
+ }
+
nicID tcpip.NICID
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
ep stack.LinkEndpoint
}
-func (*fakeNetworkEndpoint) Enable() *tcpip.Error {
+func (f *fakeNetworkEndpoint) Enable() *tcpip.Error {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.enabled = true
return nil
}
-func (*fakeNetworkEndpoint) Enabled() bool {
- return true
+func (f *fakeNetworkEndpoint) Enabled() bool {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.mu.enabled
}
-func (*fakeNetworkEndpoint) Disable() {}
+func (f *fakeNetworkEndpoint) Disable() {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.enabled = false
+}
func (f *fakeNetworkEndpoint) MTU() uint32 {
return f.ep.MTU() - uint32(f.MaxHeaderLength())
@@ -3620,3 +3635,43 @@ func TestGetNetworkEndpoint(t *testing.T) {
})
}
}
+
+func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ })
+
+ if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err)
+ }
+
+ // Check that we get the right initial address and prefix length.
+ if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
+ } else if gotAddr != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ }
+
+ // Should still get the address when the NIC is diabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
+ } else if gotAddr != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ }
+}