From afbdf2f212739880e70a5450a9292e3265caecba Mon Sep 17 00:00:00 2001 From: Chris Kuiper Date: Fri, 30 Aug 2019 17:18:05 -0700 Subject: Fix data race accessing referencedNetworkEndpoint.kind Wrapping "kind" into atomic access functions. Fixes #789 PiperOrigin-RevId: 266485501 --- pkg/tcpip/stack/nic.go | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) (limited to 'pkg/tcpip/stack/nic.go') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 89b4c5960..f947b55db 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -139,7 +139,7 @@ func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Add if list, ok := n.primary[protocol]; ok { for e := list.Front(); e != nil; e = e.Next() { ref := e.(*referencedNetworkEndpoint) - if ref.kind == permanent && ref.tryIncRef() { + if ref.getKind() == permanent && ref.tryIncRef() { r = ref break } @@ -205,7 +205,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t if ref, ok := n.endpoints[id]; ok { // An endpoint with this id exists, check if it can be used and return it. - switch ref.kind { + switch ref.getKind() { case permanentExpired: if !spoofingOrPromiscuous { n.mu.RUnlock() @@ -276,14 +276,14 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) { id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} if ref, ok := n.endpoints[id]; ok { - switch ref.kind { + switch ref.getKind() { case permanent: // The NIC already have a permanent endpoint with that address. return nil, tcpip.ErrDuplicateAddress case permanentExpired, temporary: // Promote the endpoint to become permanent. if ref.tryIncRef() { - ref.kind = permanent + ref.setKind(permanent) return ref, nil } // tryIncRef failing means the endpoint is scheduled to be removed once @@ -366,7 +366,7 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { for nid, ref := range n.endpoints { // Don't include expired or tempory endpoints to avoid confusion and // prevent the caller from using those. - switch ref.kind { + switch ref.getKind() { case permanentExpired, temporary: continue } @@ -444,7 +444,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { return } - if r.kind == permanent { + if r.getKind() == permanent { panic("Reference count dropped to zero before being removed") } @@ -465,11 +465,11 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || r.kind != permanent { + if r == nil || r.getKind() != permanent { return tcpip.ErrBadLocalAddress } - r.kind = permanentExpired + r.setKind(permanentExpired) r.decRefLocked() return nil @@ -720,7 +720,7 @@ func (n *NIC) ID() tcpip.NICID { return n.id } -type networkEndpointKind int +type networkEndpointKind int32 const ( // A permanent endpoint is created by adding a permanent address (vs. a @@ -759,21 +759,30 @@ type referencedNetworkEndpoint struct { // triggers the automatic removal of the endpoint from the NIC. refs int32 + // networkEndpointKind must only be accessed using {get,set}Kind(). kind networkEndpointKind } +func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { + return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind))) +} + +func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { + atomic.StoreInt32((*int32)(&r.kind), int32(kind)) +} + // isValidForOutgoing returns true if the endpoint can be used to send out a // packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed), or the NIC to be in spoofing mode. func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { - return r.kind != permanentExpired || r.nic.spoofing + return r.getKind() != permanentExpired || r.nic.spoofing } // isValidForIncoming returns true if the endpoint can accept an incoming // packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed), or the NIC to be in promiscuous mode. func (r *referencedNetworkEndpoint) isValidForIncoming() bool { - return r.kind != permanentExpired || r.nic.promiscuous + return r.getKind() != permanentExpired || r.nic.promiscuous } // decRef decrements the ref count and cleans up the endpoint once it reaches -- cgit v1.2.3