diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2020-01-22 14:50:32 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-01-22 14:51:53 -0800 |
commit | 1d97adaa6d73dd897bb4e89d4533936a95003951 (patch) | |
tree | d5f6b80ca14f754a6628d442de636abb31d63d7f /pkg/tcpip/stack | |
parent | 8a5bfd70011752ebac93540e83354bab11c63735 (diff) |
Use embedded mutex pattern for stack.NIC
- Wrap NIC's fields that should only be accessed while holding the mutex in
an anonymous struct with the embedded mutex.
- Make sure NIC's spoofing and promiscuous mode flags are only read while
holding the NIC's mutex.
- Use the correct endpoint when sending DAD messages.
- Do not hold the NIC's lock when sending DAD messages.
This change does not introduce any behaviour changes.
Tests: Existing tests continue to pass.
PiperOrigin-RevId: 291036251
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/ndp.go | 181 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 45 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 236 |
3 files changed, 234 insertions, 228 deletions
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 7d4b41dfa..d983ac390 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -15,7 +15,6 @@ package stack import ( - "fmt" "log" "math/rand" "time" @@ -429,8 +428,13 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref return tcpip.ErrAddressFamilyNotSupported } - // Should not attempt to perform DAD on an address that is currently in - // the DAD process. + if ref.getKind() != permanentTentative { + // The endpoint should be marked as tentative since we are starting DAD. + log.Fatalf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()) + } + + // Should not attempt to perform DAD on an address that is currently in the + // DAD process. if _, ok := ndp.dad[addr]; ok { // Should never happen because we should only ever call this function for // newly created addresses. If we attemped to "add" an address that already @@ -438,77 +442,79 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // address, or its reference count would have been increased without doing // the work that would have been done for an address that was brand new. // See NIC.addAddressLocked. - panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) + log.Fatalf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()) } remaining := ndp.configs.DupAddrDetectTransmits - - { - done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref) - if err != nil { - return err - } - if done { - return nil - } + if remaining == 0 { + ref.setKind(permanent) + return nil } - remaining-- - var done bool var timer *time.Timer - timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() { - var d bool - var err *tcpip.Error - - // doDadIteration does a single iteration of the DAD loop. - // - // Returns true if the integrator needs to be informed of DAD - // completing. - doDadIteration := func() bool { - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if done { - // If we reach this point, it means that the DAD - // timer fired after another goroutine already - // obtained the NIC lock and stopped DAD before - // this function obtained the NIC lock. Simply - // return here and do nothing further. - return false - } + // We initially start a timer to fire immediately because some of the DAD work + // cannot be done while holding the NIC's lock. This is effectively the same + // as starting a goroutine but we use a timer that fires immediately so we can + // reset it for the next DAD iteration. + timer = time.AfterFunc(0, func() { + ndp.nic.mu.RLock() + if done { + // If we reach this point, it means that the DAD timer fired after + // another goroutine already obtained the NIC lock and stopped DAD + // before this function obtained the NIC lock. Simply return here and do + // nothing further. + ndp.nic.mu.RUnlock() + return + } - ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}] - if !ok { - // This should never happen. - // We should have an endpoint for addr since we - // are still performing DAD on it. If the - // endpoint does not exist, but we are doing DAD - // on it, then we started DAD at some point, but - // forgot to stop it when the endpoint was - // deleted. - panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID())) - } + if ref.getKind() != permanentTentative { + // The endpoint should still be marked as tentative since we are still + // performing DAD on it. + log.Fatalf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID()) + } - d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref) - if err != nil || d { - delete(ndp.dad, addr) + dadDone := remaining == 0 + ndp.nic.mu.RUnlock() - if err != nil { - log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) - } + var err *tcpip.Error + if !dadDone { + err = ndp.sendDADPacket(addr) + } - // Let the integrator know DAD has completed. - return true - } + ndp.nic.mu.Lock() + if done { + // If we reach this point, it means that DAD was stopped after we released + // the NIC's read lock and before we obtained the write lock. + ndp.nic.mu.Unlock() + return + } + if dadDone { + // DAD has resolved. + ref.setKind(permanent) + } else if err == nil { + // DAD is not done and we had no errors when sending the last NDP NS, + // schedule the next DAD timer. remaining-- timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) - return false + + ndp.nic.mu.Unlock() + return + } + + // At this point we know that either DAD is done or we hit an error sending + // the last NDP NS. Either way, clean up addr's DAD state and let the + // integrator know DAD has completed. + delete(ndp.dad, addr) + ndp.nic.mu.Unlock() + + if err != nil { + log.Printf("ndpdad: error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) } - if doDadIteration() && ndp.nic.stack.ndpDisp != nil { - ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err) + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err) } }) @@ -520,45 +526,16 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref return nil } -// doDuplicateAddressDetection is called on every iteration of the timer, and -// when DAD starts. -// -// It handles resolving the address (if there are no more NS to send), or -// sending the next NS if there are more NS to send. -// -// This function must only be called by IPv6 addresses that are currently -// tentative. -// -// The NIC that ndp belongs to (n) MUST be locked. +// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns +// addr. // -// Returns true if DAD has resolved; false if DAD is still ongoing. -func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) { - if ref.getKind() != permanentTentative { - // The endpoint should still be marked as tentative - // since we are still performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) - } - - if remaining == 0 { - // DAD has resolved. - ref.setKind(permanent) - return true, nil - } - - // Send a new NS. +// addr must be a tentative IPv6 address on ndp's NIC. +func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}] - if !ok { - // This should never happen as if we have the - // address, we should have the solicited-node - // address. - panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr)) - } - snmcRef.incRef() - // Use the unspecified address as the source address when performing - // DAD. - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false) + // Use the unspecified address as the source address when performing DAD. + ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) + r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) @@ -569,15 +546,19 @@ func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining u pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}, tcpip.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := r.WritePacket(nil, + NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + TOS: DefaultTOS, + }, tcpip.PacketBuffer{Header: hdr}, + ); err != nil { sent.Dropped.Increment() - return false, err + return err } sent.NeighborSolicit.Increment() - return false, nil + return nil } // stopDuplicateAddressDetection ends a running Duplicate Address Detection @@ -1212,7 +1193,7 @@ func (ndp *ndpState) startSolicitingRouters() { ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { // Send an RS message with the unspecified source address. - ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, true) + ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 1a52e0e68..376681b30 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -301,6 +301,8 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s // Included in the subtests is a test to make sure that an invalid // RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. func TestDADResolve(t *testing.T) { + const nicID = 1 + tests := []struct { name string dupAddrDetectTransmits uint8 @@ -331,44 +333,36 @@ func TestDADResolve(t *testing.T) { opts.NDPConfigs.RetransmitTimer = test.retransTimer opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits - e := channel.New(10, 1280, linkAddr1) + e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1) s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) - } - - stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit - - // Should have sent an NDP NS immediately. - if got := stat.Value(); got != 1 { - t.Fatalf("got NeighborSolicit = %d, want = 1", got) - + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) } // Address should not be considered bound to the NIC yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Wait for the remaining time - some delta (500ms), to // make sure the address is still not resolved. const delta = 500 * time.Millisecond time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Wait for DAD to resolve. @@ -385,8 +379,8 @@ func TestDADResolve(t *testing.T) { if e.err != nil { t.Fatal("got DAD error: ", e.err) } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + if e.nicID != nicID { + t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID) } if e.addr != addr1 { t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) @@ -395,16 +389,16 @@ func TestDADResolve(t *testing.T) { t.Fatal("got DAD event w/ resolved = false, want = true") } } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) } // Should not have sent any more NS messages. - if got := stat.Value(); got != uint64(test.dupAddrDetectTransmits) { + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) } @@ -425,7 +419,6 @@ func TestDADResolve(t *testing.T) { } }) } - } // TestDADFail tests to make sure that the DAD process fails if another node is diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index de88c0bfa..79556a36f 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -35,24 +35,21 @@ type NIC struct { linkEP LinkEndpoint context NICContext - mu sync.RWMutex - spoofing bool - promiscuous bool - primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - addressRanges []tcpip.Subnet - mcastJoins map[NetworkEndpointID]int32 - // packetEPs is protected by mu, but the contained PacketEndpoint - // values are not. - packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint - stats NICStats - // ndp is the NDP related state for NIC. - // - // Note, read and write operations on ndp require that the NIC is - // appropriately locked. - ndp ndpState + mu struct { + sync.RWMutex + spoofing bool + promiscuous bool + primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint + endpoints map[NetworkEndpointID]*referencedNetworkEndpoint + addressRanges []tcpip.Subnet + mcastJoins map[NetworkEndpointID]int32 + // packetEPs is protected by mu, but the contained PacketEndpoint + // values are not. + packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint + ndp ndpState + } } // NICStats includes transmitted and received stats. @@ -97,15 +94,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // of IPv6 is supported on this endpoint's LinkEndpoint. nic := &NIC{ - stack: stack, - id: id, - name: name, - linkEP: ep, - context: ctx, - primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), - endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), - mcastJoins: make(map[NetworkEndpointID]int32), - packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint), + stack: stack, + id: id, + name: name, + linkEP: ep, + context: ctx, stats: NICStats{ Tx: DirectionStats{ Packets: &tcpip.StatCounter{}, @@ -116,22 +109,26 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC Bytes: &tcpip.StatCounter{}, }, }, - ndp: ndpState{ - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), - }, } - nic.ndp.nic = nic + nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) + nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) + nic.mu.mcastJoins = make(map[NetworkEndpointID]int32) + nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) + nic.mu.ndp = ndpState{ + nic: nic, + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), + } // Register supported packet endpoint protocols. for _, netProto := range header.Ethertypes { - nic.packetEPs[netProto] = []PacketEndpoint{} + nic.mu.packetEPs[netProto] = []PacketEndpoint{} } for _, netProto := range stack.networkProtocols { - nic.packetEPs[netProto.Number()] = []PacketEndpoint{} + nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{} } return nic @@ -215,7 +212,7 @@ func (n *NIC) enable() *tcpip.Error { // and default routers). Therefore, soliciting RAs from other routers on // a link is unnecessary for routers. if !n.stack.forwarding { - n.ndp.startSolicitingRouters() + n.mu.ndp.startSolicitingRouters() } return nil @@ -230,8 +227,8 @@ func (n *NIC) becomeIPv6Router() { n.mu.Lock() defer n.mu.Unlock() - n.ndp.cleanupHostOnlyState() - n.ndp.stopSolicitingRouters() + n.mu.ndp.cleanupHostOnlyState() + n.mu.ndp.stopSolicitingRouters() } // becomeIPv6Host transitions n into an IPv6 host. @@ -242,7 +239,7 @@ func (n *NIC) becomeIPv6Host() { n.mu.Lock() defer n.mu.Unlock() - n.ndp.startSolicitingRouters() + n.mu.ndp.startSolicitingRouters() } // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it @@ -254,13 +251,13 @@ func (n *NIC) attachLinkEndpoint() { // setPromiscuousMode enables or disables promiscuous mode. func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Lock() - n.promiscuous = enable + n.mu.promiscuous = enable n.mu.Unlock() } func (n *NIC) isPromiscuousMode() bool { n.mu.RLock() - rv := n.promiscuous + rv := n.mu.promiscuous n.mu.RUnlock() return rv } @@ -272,7 +269,7 @@ func (n *NIC) isLoopback() bool { // setSpoofing enables or disables address spoofing. func (n *NIC) setSpoofing(enable bool) { n.mu.Lock() - n.spoofing = enable + n.mu.spoofing = enable n.mu.Unlock() } @@ -291,8 +288,8 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr t defer n.mu.RUnlock() var deprecatedEndpoint *referencedNetworkEndpoint - for _, r := range n.primary[protocol] { - if !r.isValidForOutgoing() { + for _, r := range n.mu.primary[protocol] { + if !r.isValidForOutgoingRLocked() { continue } @@ -342,7 +339,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn n.mu.RLock() defer n.mu.RUnlock() - primaryAddrs := n.primary[header.IPv6ProtocolNumber] + primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] if len(primaryAddrs) == 0 { return nil @@ -425,7 +422,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn // hasPermanentAddrLocked returns true if n has a permanent (including currently // tentative) address, addr. func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { - ref, ok := n.endpoints[NetworkEndpointID{addr}] + ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return false @@ -436,24 +433,54 @@ func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { return kind == permanent || kind == permanentTentative } +type getRefBehaviour int + +const ( + // spoofing indicates that the NIC's spoofing flag should be observed when + // getting a NIC's referenced network endpoint. + spoofing getRefBehaviour = iota + + // promiscuous indicates that the NIC's promiscuous flag should be observed + // when getting a NIC's referenced network endpoint. + promiscuous + + // forceSpoofing indicates that the NIC should be assumed to be spoofing, + // regardless of what the NIC's spoofing flag is when getting a NIC's + // referenced network endpoint. + forceSpoofing +) + func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) + return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing) + return n.getRefOrCreateTemp(protocol, address, peb, spoofing) } // getRefEpOrCreateTemp returns the referenced network endpoint for the given -// protocol and address. If none exists a temporary one may be created if -// we are in promiscuous mode or spoofing. -func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint { +// protocol and address. +// +// If none exists a temporary one may be created if we are in promiscuous mode +// or spoofing. Promiscuous mode will only be checked if promiscuous is true. +// Similarly, spoofing will only be checked if spoofing is true. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() - if ref, ok := n.endpoints[id]; ok { + var spoofingOrPromiscuous bool + switch tempRef { + case spoofing: + spoofingOrPromiscuous = n.mu.spoofing + case promiscuous: + spoofingOrPromiscuous = n.mu.promiscuous + case forceSpoofing: + spoofingOrPromiscuous = true + } + + if ref, ok := n.mu.endpoints[id]; ok { // An endpoint with this id exists, check if it can be used and return it. switch ref.getKind() { case permanentExpired: @@ -474,7 +501,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // the caller or if the address is found in the NIC's subnets. createTempEP := spoofingOrPromiscuous if !createTempEP { - for _, sn := range n.addressRanges { + for _, sn := range n.mu.addressRanges { // Skip the subnet address. if address == sn.ID() { continue @@ -502,7 +529,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - if ref, ok := n.endpoints[id]; ok { + if ref, ok := n.mu.endpoints[id]; ok { // No need to check the type as we are ok with expired endpoints at this // point. if ref.tryIncRef() { @@ -543,7 +570,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar // Sanity check. id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address} - if ref, ok := n.endpoints[id]; ok { + if ref, ok := n.mu.endpoints[id]; ok { // Endpoint already exists. if kind != permanent { return nil, tcpip.ErrDuplicateAddress @@ -562,7 +589,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar ref.deprecated = deprecated ref.configType = configType - refs := n.primary[ref.protocol] + refs := n.mu.primary[ref.protocol] for i, r := range refs { if r == ref { switch peb { @@ -572,9 +599,9 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar if i == 0 { return ref, nil } - n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) case NeverPrimaryEndpoint: - n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) return ref, nil } } @@ -637,13 +664,13 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } - n.endpoints[id] = ref + n.mu.endpoints[id] = ref n.insertPrimaryEndpointLocked(ref, peb) // If we are adding a tentative IPv6 address, start DAD. if isIPv6Unicast && kind == permanentTentative { - if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { + if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { return nil, err } } @@ -668,8 +695,8 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { n.mu.RLock() defer n.mu.RUnlock() - addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) - for nid, ref := range n.endpoints { + addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints)) + for nid, ref := range n.mu.endpoints { // Don't include tentative, expired or temporary endpoints to // avoid confusion and prevent the caller from using those. switch ref.getKind() { @@ -695,7 +722,7 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { defer n.mu.RUnlock() var addrs []tcpip.ProtocolAddress - for proto, list := range n.primary { + for proto, list := range n.mu.primary { for _, ref := range list { // Don't include tentative, expired or tempory endpoints // to avoid confusion and prevent the caller from using @@ -726,7 +753,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit n.mu.RLock() defer n.mu.RUnlock() - list, ok := n.primary[proto] + list, ok := n.mu.primary[proto] if !ok { return tcpip.AddressWithPrefix{} } @@ -769,7 +796,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit // address. func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { n.mu.Lock() - n.addressRanges = append(n.addressRanges, subnet) + n.mu.addressRanges = append(n.mu.addressRanges, subnet) n.mu.Unlock() } @@ -778,13 +805,13 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Lock() // Use the same underlying array. - tmp := n.addressRanges[:0] - for _, sub := range n.addressRanges { + tmp := n.mu.addressRanges[:0] + for _, sub := range n.mu.addressRanges { if sub != subnet { tmp = append(tmp, sub) } } - n.addressRanges = tmp + n.mu.addressRanges = tmp n.mu.Unlock() } @@ -793,8 +820,8 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints)) - for nid := range n.endpoints { + sns := make([]tcpip.Subnet, 0, len(n.mu.addressRanges)+len(n.mu.endpoints)) + for nid := range n.mu.endpoints { sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) if err != nil { // This should never happen as the mask has been carefully crafted to @@ -803,7 +830,7 @@ func (n *NIC) AddressRanges() []tcpip.Subnet { } sns = append(sns, sn) } - return append(sns, n.addressRanges...) + return append(sns, n.mu.addressRanges...) } // insertPrimaryEndpointLocked adds r to n's primary endpoint list as required @@ -813,9 +840,9 @@ func (n *NIC) AddressRanges() []tcpip.Subnet { func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) { switch peb { case CanBePrimaryEndpoint: - n.primary[r.protocol] = append(n.primary[r.protocol], r) + n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r) case FirstPrimaryEndpoint: - n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...) + n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...) } } @@ -827,7 +854,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { // and was waiting (on the lock) to be removed and 2) the same address was // re-added in the meantime by removing this endpoint from the list and // adding a new one. - if n.endpoints[id] != r { + if n.mu.endpoints[id] != r { return } @@ -835,11 +862,11 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { panic("Reference count dropped to zero before being removed") } - delete(n.endpoints, id) - refs := n.primary[r.protocol] + delete(n.mu.endpoints, id) + refs := n.mu.primary[r.protocol] for i, ref := range refs { if ref == r { - n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) break } } @@ -854,7 +881,7 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { } func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r, ok := n.endpoints[NetworkEndpointID{addr}] + r, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return tcpip.ErrBadLocalAddress } @@ -870,13 +897,13 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { // If we are removing a tentative IPv6 unicast address, stop // DAD. if kind == permanentTentative { - n.ndp.stopDuplicateAddressDetection(addr) + n.mu.ndp.stopDuplicateAddressDetection(addr) } // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. if r.configType == slaac { - n.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) + n.mu.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) } } @@ -926,7 +953,7 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A // outlined in RFC 3810 section 5. id := NetworkEndpointID{addr} - joins := n.mcastJoins[id] + joins := n.mu.mcastJoins[id] if joins == 0 { netProto, ok := n.stack.networkProtocols[protocol] if !ok { @@ -942,7 +969,7 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A return err } } - n.mcastJoins[id] = joins + 1 + n.mu.mcastJoins[id] = joins + 1 return nil } @@ -960,7 +987,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { // before leaveGroupLocked is called. func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} - joins := n.mcastJoins[id] + joins := n.mu.mcastJoins[id] switch joins { case 0: // There are no joins with this address on this NIC. @@ -971,7 +998,7 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { return err } } - n.mcastJoins[id] = joins - 1 + n.mu.mcastJoins[id] = joins - 1 return nil } @@ -1006,12 +1033,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // Are any packet sockets listening for this network protocol? n.mu.RLock() - packetEPs := n.packetEPs[protocol] + packetEPs := n.mu.packetEPs[protocol] // Check whether there are packet sockets listening for every protocol. // If we received a packet with protocol EthernetProtocolAll, then the // previous for loop will have handled it. if protocol != header.EthernetProtocolAll { - packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...) + packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) } n.mu.RUnlock() for _, ep := range packetEPs { @@ -1060,8 +1087,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // Found a NIC. n := r.ref.nic n.mu.RLock() - ref, ok := n.endpoints[NetworkEndpointID{dst}] - ok = ok && ref.isValidForOutgoing() && ref.tryIncRef() + ref, ok := n.mu.endpoints[NetworkEndpointID{dst}] + ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef() n.mu.RUnlock() if ok { r.RemoteAddress = src @@ -1181,7 +1208,7 @@ func (n *NIC) Stack() *Stack { // false. It will only return true if the address is associated with the NIC // AND it is tentative. func (n *NIC) isAddrTentative(addr tcpip.Address) bool { - ref, ok := n.endpoints[NetworkEndpointID{addr}] + ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return false } @@ -1197,7 +1224,7 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - ref, ok := n.endpoints[NetworkEndpointID{addr}] + ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return tcpip.ErrBadAddress } @@ -1217,7 +1244,7 @@ func (n *NIC) setNDPConfigs(c NDPConfigurations) { c.validate() n.mu.Lock() - n.ndp.configs = c + n.mu.ndp.configs = c n.mu.Unlock() } @@ -1226,7 +1253,7 @@ func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { n.mu.Lock() defer n.mu.Unlock() - n.ndp.handleRA(ip, ra) + n.mu.ndp.handleRA(ip, ra) } type networkEndpointKind int32 @@ -1268,11 +1295,11 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa n.mu.Lock() defer n.mu.Unlock() - eps, ok := n.packetEPs[netProto] + eps, ok := n.mu.packetEPs[netProto] if !ok { return tcpip.ErrNotSupported } - n.packetEPs[netProto] = append(eps, ep) + n.mu.packetEPs[netProto] = append(eps, ep) return nil } @@ -1281,14 +1308,14 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep n.mu.Lock() defer n.mu.Unlock() - eps, ok := n.packetEPs[netProto] + eps, ok := n.mu.packetEPs[netProto] if !ok { return } for i, epOther := range eps { if epOther == ep { - n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) + n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) return } } @@ -1346,14 +1373,19 @@ func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { // packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed), or the NIC to be in spoofing mode. func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { - return r.getKind() != permanentExpired || r.nic.spoofing + r.nic.mu.RLock() + defer r.nic.mu.RUnlock() + + return r.isValidForOutgoingRLocked() } -// isValidForIncoming returns true if the endpoint can accept an incoming -// packet. It requires the endpoint to not be marked expired (i.e., its address -// has been removed), or the NIC to be in promiscuous mode. -func (r *referencedNetworkEndpoint) isValidForIncoming() bool { - return r.getKind() != permanentExpired || r.nic.promiscuous +// isValidForOutgoingRLocked returns true if the endpoint can be used to send +// out a packet. It requires the endpoint to not be marked expired (i.e., its +// address has been removed), or the NIC to be in spoofing mode. +// +// r's NIC must be read locked. +func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { + return r.getKind() != permanentExpired || r.nic.mu.spoofing } // decRef decrements the ref count and cleans up the endpoint once it reaches |