From e5ece9aea730c105ab336e6bd2858322686a5708 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 29 Sep 2020 19:44:42 -0700 Subject: Return permanent addresses when NIC is down Test: stack_test.TestGetMainNICAddressWhenNICDisabled PiperOrigin-RevId: 334513286 --- pkg/tcpip/stack/addressable_endpoint_state.go | 111 ++++++++++++++------- pkg/tcpip/stack/addressable_endpoint_state_test.go | 11 +- pkg/tcpip/stack/nic.go | 22 ++-- pkg/tcpip/stack/registration.go | 13 ++- pkg/tcpip/stack/stack_test.go | 63 +++++++++++- 5 files changed, 160 insertions(+), 60 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 270ac4977..db8ac1c2b 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -412,6 +412,60 @@ func (a *AddressableEndpointState) decAddressRefLocked(addrState *addressState) a.releaseAddressStateLocked(addrState) } +// MainAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix { + a.mu.RLock() + defer a.mu.RUnlock() + + ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool { + return ep.GetKind() == Permanent + }) + if ep == nil { + return tcpip.AddressWithPrefix{} + } + + addr := ep.AddressWithPrefix() + a.decAddressRefLocked(ep) + return addr +} + +// acquirePrimaryAddressRLocked returns an acquired primary address that is +// valid according to isValid. +// +// Precondition: e.mu must be read locked +func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*addressState) bool) *addressState { + var deprecatedEndpoint *addressState + for _, ep := range a.mu.primary { + if !isValid(ep) { + continue + } + + if !ep.Deprecated() { + if ep.IncRef() { + // ep is not deprecated, so return it immediately. + // + // If we kept track of a deprecated endpoint, decrement its reference + // count since it was incremented when we decided to keep track of it. + if deprecatedEndpoint != nil { + a.decAddressRefLocked(deprecatedEndpoint) + deprecatedEndpoint = nil + } + + return ep + } + } else if deprecatedEndpoint == nil && ep.IncRef() { + // We prefer an endpoint that is not deprecated, but we keep track of + // ep in case a doesn't have any non-deprecated endpoints. + // + // If we end up finding a more preferred endpoint, ep's reference count + // will be decremented. + deprecatedEndpoint = ep + } + } + + return deprecatedEndpoint +} + // AcquireAssignedAddress implements AddressableEndpoint. func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { a.mu.Lock() @@ -461,47 +515,34 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres return ep } -// AcquirePrimaryAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AcquirePrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint { +// AcquireOutgoingPrimaryAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint { a.mu.RLock() defer a.mu.RUnlock() - var deprecatedEndpoint *addressState - for _, ep := range a.mu.primary { - if !ep.IsAssigned(allowExpired) { - continue - } - - if !ep.Deprecated() { - if ep.IncRef() { - // ep is not deprecated, so return it immediately. - // - // If we kept track of a deprecated endpoint, decrement its reference - // count since it was incremented when we decided to keep track of it. - if deprecatedEndpoint != nil { - a.decAddressRefLocked(deprecatedEndpoint) - deprecatedEndpoint = nil - } - - return ep - } - } else if deprecatedEndpoint == nil && ep.IncRef() { - // We prefer an endpoint that is not deprecated, but we keep track of - // ep in case a doesn't have any non-deprecated endpoints. - // - // If we end up finding a more preferred endpoint, ep's reference count - // will be decremented. - deprecatedEndpoint = ep - } - } + ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool { + return ep.IsAssigned(allowExpired) + }) - // a doesn't have any valid non-deprecated endpoints, so return - // deprecatedEndpoint (which may be nil if a doesn't have any valid deprecated - // endpoints either). - if deprecatedEndpoint == nil { + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since acquirePrimaryAddressRLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { return nil } - return deprecatedEndpoint + + return ep } // PrimaryAddresses implements AddressableEndpoint. diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index de4e0d7b1..26787d0a3 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -24,8 +24,13 @@ import ( // TestAddressableEndpointStateCleanup tests that cleaning up an addressable // endpoint state removes permanent addresses and leaves groups. func TestAddressableEndpointStateCleanup(t *testing.T) { + var ep fakeNetworkEndpoint + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + var s stack.AddressableEndpointState - s.Init(&fakeNetworkEndpoint{}) + s.Init(&ep) addr := tcpip.AddressWithPrefix{ Address: "\x01", @@ -43,7 +48,7 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { { ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) if ep == nil { - t.Fatalf("got s.AcquireAssignedAddress(%s) = nil, want = non-nil", addr.Address) + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address) } ep.DecRef() } @@ -63,7 +68,7 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) if ep != nil { ep.DecRef() - t.Fatalf("got s.AcquireAssignedAddress(%s) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) } } if s.IsInGroup(group) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 212c6edae..23022292c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -244,22 +244,19 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -// primaryEndpoint will return the first non-deprecated endpoint if such an -// endpoint exists for the given protocol and remoteAddr. If no non-deprecated -// endpoint exists, the first deprecated endpoint will be returned. -// -// If an IPv6 primary endpoint is requested, Source Address Selection (as -// defined by RFC 6724 section 5) will be performed. +// primaryAddress returns an address that can be used to communicate with +// remoteAddr. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { n.mu.RLock() - defer n.mu.RUnlock() + spoofing := n.mu.spoofing + n.mu.RUnlock() ep, ok := n.networkEndpoints[protocol] if !ok { return nil } - return ep.AcquirePrimaryAddress(remoteAddr, n.mu.spoofing) + return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) } type getAddressBehaviour int @@ -360,13 +357,12 @@ func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { // address exists. If no non-deprecated address exists, the first deprecated // address will be returned. func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { - addressEndpoint := n.primaryEndpoint(proto, "") - if addressEndpoint == nil { + ep, ok := n.networkEndpoints[proto] + if !ok { return tcpip.AddressWithPrefix{} } - addr := addressEndpoint.AddressWithPrefix() - addressEndpoint.DecRef() - return addr + + return ep.MainAddress() } // removeAddress removes an address from n. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 567e1904e..b6f823b54 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -435,7 +435,10 @@ type AddressableEndpoint interface { // permanent address. RemovePermanentAddress(addr tcpip.Address) *tcpip.Error - // AcquireAssignedAddress returns an AddressEndpoint for the passed address + // MainAddress returns the endpoint's primary permanent address. + MainAddress() tcpip.AddressWithPrefix + + // AcquireAssignedAddress returns an address endpoint for the passed address // that is considered bound to the endpoint, optionally creating a temporary // endpoint if requested and no existing address exists. // @@ -444,15 +447,15 @@ type AddressableEndpoint interface { // Returns nil if the specified address is not local to this endpoint. AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint - // AcquirePrimaryAddress returns a primary endpoint to use when communicating - // with the passed remote address. + // AcquireOutgoingPrimaryAddress returns a primary address that may be used as + // a source address when sending packets to the passed remote address. // // If allowExpired is true, expired addresses may be returned. // // The returned endpoint's reference count is incremented. // - // Returns nil if a primary endpoint is not available. - AcquirePrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint + // Returns nil if a primary address is not available. + AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint // PrimaryAddresses returns the primary addresses. PrimaryAddresses() []tcpip.AddressWithPrefix diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index fda22c550..aa20f750b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -71,21 +71,36 @@ const ( type fakeNetworkEndpoint struct { stack.AddressableEndpointState + mu struct { + sync.RWMutex + + enabled bool + } + nicID tcpip.NICID proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher ep stack.LinkEndpoint } -func (*fakeNetworkEndpoint) Enable() *tcpip.Error { +func (f *fakeNetworkEndpoint) Enable() *tcpip.Error { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.enabled = true return nil } -func (*fakeNetworkEndpoint) Enabled() bool { - return true +func (f *fakeNetworkEndpoint) Enabled() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.enabled } -func (*fakeNetworkEndpoint) Disable() {} +func (f *fakeNetworkEndpoint) Disable() { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.enabled = false +} func (f *fakeNetworkEndpoint) MTU() uint32 { return f.ep.MTU() - uint32(f.MaxHeaderLength()) @@ -3620,3 +3635,43 @@ func TestGetNetworkEndpoint(t *testing.T) { }) } } + +func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + }) + + if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err) + } + + // Check that we get the right initial address and prefix length. + if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) + } else if gotAddr != protocolAddress.AddressWithPrefix { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + } + + // Should still get the address when the NIC is diabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) + } else if gotAddr != protocolAddress.AddressWithPrefix { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + } +} -- cgit v1.2.3