summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/ndp.go96
-rw-r--r--pkg/tcpip/stack/nic.go38
-rw-r--r--pkg/tcpip/stack/nic_test.go257
-rw-r--r--pkg/tcpip/stack/route.go4
-rw-r--r--pkg/tcpip/stack/stack.go7
6 files changed, 364 insertions, 39 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index f71073207..afca925ad 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -110,5 +110,6 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
],
)
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index ae7a8f740..e28c23d66 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -467,8 +467,17 @@ type ndpState struct {
// The default routers discovered through Router Advertisements.
defaultRouters map[tcpip.Address]defaultRouterState
- // The timer used to send the next router solicitation message.
- rtrSolicitTimer *time.Timer
+ rtrSolicit struct {
+ // The timer used to send the next router solicitation message.
+ timer *time.Timer
+
+ // Used to let the Router Solicitation timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the NIC this ndpState is associated with. MUST be set when the timer is
+ // set.
+ done *bool
+ }
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
@@ -648,13 +657,14 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// as starting a goroutine but we use a timer that fires immediately so we can
// reset it for the next DAD iteration.
timer = time.AfterFunc(0, func() {
- ndp.nic.mu.RLock()
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
if done {
// If we reach this point, it means that the DAD timer fired after
// another goroutine already obtained the NIC lock and stopped DAD
// before this function obtained the NIC lock. Simply return here and do
// nothing further.
- ndp.nic.mu.RUnlock()
return
}
@@ -665,15 +675,23 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
}
dadDone := remaining == 0
- ndp.nic.mu.RUnlock()
var err *tcpip.Error
if !dadDone {
- err = ndp.sendDADPacket(addr)
+ // Use the unspecified address as the source address when performing DAD.
+ ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
+
+ // Do not hold the lock when sending packets which may be a long running
+ // task or may block link address resolution. We know this is safe
+ // because immediately after obtaining the lock again, we check if DAD
+ // has been stopped before doing any work with the NIC. Note, DAD would be
+ // stopped if the NIC was disabled or removed, or if the address was
+ // removed.
+ ndp.nic.mu.Unlock()
+ err = ndp.sendDADPacket(addr, ref)
+ ndp.nic.mu.Lock()
}
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
if done {
// If we reach this point, it means that DAD was stopped after we released
// the NIC's read lock and before we obtained the write lock.
@@ -721,17 +739,24 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// addr.
//
// addr must be a tentative IPv6 address on ndp's NIC.
-func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
+//
+// The NIC ndp belongs to MUST NOT be locked.
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- // Use the unspecified address as the source address when performing DAD.
- ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
- r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
// Route should resolve immediately since snmc is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the NIC is not locked,
+ // the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
+ return err
+ }
+
panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err))
} else if c != nil {
panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()))
@@ -1816,7 +1841,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
//
// The NIC ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
- if ndp.rtrSolicitTimer != nil {
+ if ndp.rtrSolicit.timer != nil {
// We are already soliciting routers.
return
}
@@ -1833,14 +1858,27 @@ func (ndp *ndpState) startSolicitingRouters() {
delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
}
- ndp.rtrSolicitTimer = time.AfterFunc(delay, func() {
+ var done bool
+ ndp.rtrSolicit.done = &done
+ ndp.rtrSolicit.timer = time.AfterFunc(delay, func() {
+ ndp.nic.mu.Lock()
+ if done {
+ // If we reach this point, it means that the RS timer fired after another
+ // goroutine already obtained the NIC lock and stopped solicitations.
+ // Simply return here and do nothing further.
+ ndp.nic.mu.Unlock()
+ return
+ }
+
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress)
+ ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress)
if ref == nil {
- ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
+ ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
}
+ ndp.nic.mu.Unlock()
+
localAddr := ref.ep.ID().LocalAddress
r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
@@ -1849,6 +1887,13 @@ func (ndp *ndpState) startSolicitingRouters() {
// header.IPv6AllRoutersMulticastAddress is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the NIC is not locked,
+ // the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
+ return
+ }
+
panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err))
} else if c != nil {
panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()))
@@ -1893,17 +1938,18 @@ func (ndp *ndpState) startSolicitingRouters() {
}
ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
- if remaining == 0 {
- ndp.rtrSolicitTimer = nil
- } else if ndp.rtrSolicitTimer != nil {
+ if done || remaining == 0 {
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
+ } else if ndp.rtrSolicit.timer != nil {
// Note, we need to explicitly check to make sure that
// the timer field is not nil because if it was nil but
// we still reached this point, then we know the NIC
// was requested to stop soliciting routers so we don't
// need to send the next Router Solicitation message.
- ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval)
+ ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
}
+ ndp.nic.mu.Unlock()
})
}
@@ -1913,13 +1959,15 @@ func (ndp *ndpState) startSolicitingRouters() {
//
// The NIC ndp belongs to MUST be locked.
func (ndp *ndpState) stopSolicitingRouters() {
- if ndp.rtrSolicitTimer == nil {
+ if ndp.rtrSolicit.timer == nil {
// Nothing to do.
return
}
- ndp.rtrSolicitTimer.Stop()
- ndp.rtrSolicitTimer = nil
+ *ndp.rtrSolicit.done = true
+ ndp.rtrSolicit.timer.Stop()
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
}
// initializeTempAddrState initializes state related to temporary SLAAC
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index d756ae6f5..644c0d437 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -457,8 +457,20 @@ type ipv6AddrCandidate struct {
// remoteAddr must be a valid IPv6 address.
func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
n.mu.RLock()
- defer n.mu.RUnlock()
+ ref := n.primaryIPv6EndpointRLocked(remoteAddr)
+ n.mu.RUnlock()
+ return ref
+}
+// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address
+// Selection (RFC 6724 section 5).
+//
+// Note, only rules 1-3 and 7 are followed.
+//
+// remoteAddr must be a valid IPv6 address.
+//
+// n.mu MUST be read locked.
+func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber]
if len(primaryAddrs) == 0 {
@@ -568,11 +580,6 @@ const (
// promiscuous indicates that the NIC's promiscuous flag should be observed
// when getting a NIC's referenced network endpoint.
promiscuous
-
- // forceSpoofing indicates that the NIC should be assumed to be spoofing,
- // regardless of what the NIC's spoofing flag is when getting a NIC's
- // referenced network endpoint.
- forceSpoofing
)
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
@@ -591,8 +598,6 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
// or spoofing. Promiscuous mode will only be checked if promiscuous is true.
// Similarly, spoofing will only be checked if spoofing is true.
func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint {
- id := NetworkEndpointID{address}
-
n.mu.RLock()
var spoofingOrPromiscuous bool
@@ -601,11 +606,9 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
spoofingOrPromiscuous = n.mu.spoofing
case promiscuous:
spoofingOrPromiscuous = n.mu.promiscuous
- case forceSpoofing:
- spoofingOrPromiscuous = true
}
- if ref, ok := n.mu.endpoints[id]; ok {
+ if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// An endpoint with this id exists, check if it can be used and return it.
switch ref.getKind() {
case permanentExpired:
@@ -654,11 +657,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- if ref, ok := n.mu.endpoints[id]; ok {
+ ref := n.getRefOrCreateTempLocked(protocol, address, peb)
+ n.mu.Unlock()
+ return ref
+}
+
+/// getRefOrCreateTempLocked returns an existing endpoint for address or creates
+/// and returns a temporary endpoint.
+func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// No need to check the type as we are ok with expired endpoints at this
// point.
if ref.tryIncRef() {
- n.mu.Unlock()
return ref
}
// tryIncRef failing means the endpoint is scheduled to be removed once the
@@ -670,7 +680,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// Add a new temporary endpoint.
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
- n.mu.Unlock()
return nil
}
ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
@@ -681,7 +690,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
},
}, peb, temporary, static, false)
- n.mu.Unlock()
return ref
}
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index fea46158c..31f865260 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -15,11 +15,268 @@
package stack
import (
+ "math"
"testing"
+ "time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
+var _ LinkEndpoint = (*testLinkEndpoint)(nil)
+
+// A LinkEndpoint that throws away outgoing packets.
+//
+// We use this instead of the channel endpoint as the channel package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testLinkEndpoint struct {
+ dispatcher NetworkDispatcher
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (e *testLinkEndpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (*testLinkEndpoint) MTU() uint32 {
+ return math.MaxUint16
+}
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities {
+ return CapabilityResolutionRequired
+}
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*testLinkEndpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// Wait implements LinkEndpoint.Wait.
+func (*testLinkEndpoint) Wait() {}
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // Our tests don't use this so we don't support it.
+ return 0, tcpip.ErrNotSupported
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ // Our tests don't use this so we don't support it.
+ return tcpip.ErrNotSupported
+}
+
+var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
+
+// An IPv6 NetworkEndpoint that throws away outgoing packets.
+//
+// We use this instead of ipv6.endpoint because the ipv6 package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testIPv6Endpoint struct {
+ nicID tcpip.NICID
+ id NetworkEndpointID
+ prefixLen int
+ linkEP LinkEndpoint
+ protocol *testIPv6Protocol
+}
+
+// DefaultTTL implements NetworkEndpoint.DefaultTTL.
+func (*testIPv6Endpoint) DefaultTTL() uint8 {
+ return 0
+}
+
+// MTU implements NetworkEndpoint.MTU.
+func (e *testIPv6Endpoint) MTU() uint32 {
+ return e.linkEP.MTU() - header.IPv6MinimumSize
+}
+
+// Capabilities implements NetworkEndpoint.Capabilities.
+func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
+func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+}
+
+// WritePacket implements NetworkEndpoint.WritePacket.
+func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// WritePackets implements NetworkEndpoint.WritePackets.
+func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) {
+ // Our tests don't use this so we don't support it.
+ return 0, tcpip.ErrNotSupported
+}
+
+// WriteHeaderIncludedPacket implements
+// NetworkEndpoint.WriteHeaderIncludedPacket.
+func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error {
+ // Our tests don't use this so we don't support it.
+ return tcpip.ErrNotSupported
+}
+
+// ID implements NetworkEndpoint.ID.
+func (e *testIPv6Endpoint) ID() *NetworkEndpointID {
+ return &e.id
+}
+
+// PrefixLen implements NetworkEndpoint.PrefixLen.
+func (e *testIPv6Endpoint) PrefixLen() int {
+ return e.prefixLen
+}
+
+// NICID implements NetworkEndpoint.NICID.
+func (e *testIPv6Endpoint) NICID() tcpip.NICID {
+ return e.nicID
+}
+
+// HandlePacket implements NetworkEndpoint.HandlePacket.
+func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
+}
+
+// Close implements NetworkEndpoint.Close.
+func (*testIPv6Endpoint) Close() {}
+
+// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber.
+func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+var _ NetworkProtocol = (*testIPv6Protocol)(nil)
+
+// An IPv6 NetworkProtocol that supports the bare minimum to make a stack
+// believe it supports IPv6.
+//
+// We use this instead of ipv6.protocol because the ipv6 package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testIPv6Protocol struct{}
+
+// Number implements NetworkProtocol.Number.
+func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize.
+func (*testIPv6Protocol) MinimumPacketSize() int {
+ return header.IPv6MinimumSize
+}
+
+// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen.
+func (*testIPv6Protocol) DefaultPrefixLen() int {
+ return header.IPv6AddressSize * 8
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv6(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// NewEndpoint implements NetworkProtocol.NewEndpoint.
+func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
+ return &testIPv6Endpoint{
+ nicID: nicID,
+ id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ linkEP: linkEP,
+ protocol: p,
+ }, nil
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error {
+ return nil
+}
+
+// Option implements NetworkProtocol.Option.
+func (*testIPv6Protocol) Option(interface{}) *tcpip.Error {
+ return nil
+}
+
+// Close implements NetworkProtocol.Close.
+func (*testIPv6Protocol) Close() {}
+
+// Wait implements NetworkProtocol.Wait.
+func (*testIPv6Protocol) Wait() {}
+
+// Parse implements NetworkProtocol.Parse.
+func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ return 0, false, false
+}
+
+var _ LinkAddressResolver = (*testIPv6Protocol)(nil)
+
+// LinkAddressProtocol implements LinkAddressResolver.
+func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// LinkAddressRequest implements LinkAddressResolver.
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+ return nil
+}
+
+// ResolveStaticAddress implements LinkAddressResolver.
+func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if header.IsV6MulticastAddress(addr) {
+ return header.EthernetAddressFromMulticastIPv6Address(addr), true
+ }
+ return "", false
+}
+
+// Test the race condition where a NIC is removed and an RS timer fires at the
+// same time.
+func TestRemoveNICWhileHandlingRSTimer(t *testing.T) {
+ const (
+ nicID = 1
+
+ maxRtrSolicitations = 5
+ )
+
+ e := testLinkEndpoint{}
+ s := New(Options{
+ NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}},
+ NDPConfigs: NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: minimumRtrSolicitationInterval,
+ },
+ })
+
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ s.mu.Lock()
+ // Wait for the router solicitation timer to fire and block trying to obtain
+ // the stack lock when doing link address resolution.
+ time.Sleep(minimumRtrSolicitationInterval * 2)
+ if err := s.removeNICLocked(nicID); err != nil {
+ t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err)
+ }
+ s.mu.Unlock()
+}
+
func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
// When the NIC is disabled, the only field that matters is the stats field.
// This test is limited to stats counter checks.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index f5b6ca0b9..d65f8049e 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -113,6 +113,8 @@ func (r *Route) GSOMaxSize() uint32 {
// If address resolution is required, ErrNoLinkAddress and a notification channel is
// returned for the top level caller to block. Channel is closed once address resolution
// is complete (success or not).
+//
+// The NIC r uses must not be locked.
func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if !r.IsResolutionRequired() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
@@ -148,6 +150,8 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
// IsResolutionRequired returns true if Resolve() must be called to resolve
// the link address before the this route can be written to.
+//
+// The NIC r uses must not be locked.
func (r *Route) IsResolutionRequired() bool {
return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 294ce8775..648791302 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1017,6 +1017,13 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
+ return s.removeNICLocked(id)
+}
+
+// removeNICLocked removes NIC and all related routes from the network stack.
+//
+// s.mu must be locked.
+func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error {
nic, ok := s.nics[id]
if !ok {
return tcpip.ErrUnknownNICID