diff options
author | Chris Kuiper <ckuiper@google.com> | 2019-08-26 12:28:26 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2019-08-26 12:29:47 -0700 |
commit | ac2200b8a9c269926d2eb98a7c23be79b4738fcf (patch) | |
tree | 54ecc7a9a36ec1f65d8659de1f3cae1da1e5d64f /pkg/tcpip/stack/nic.go | |
parent | c9c52c024cf20c1c66327171af4287129724326e (diff) |
Prevent a network endpoint to send/rcv if its address was removed
This addresses the problem where an endpoint has its address removed but still
has outstanding references held by routes used in connected TCP/UDP sockets
which prevent the removal of the endpoint.
The fix adds a new "expired" flag to the referenced network endpoint, which is
set when an endpoint has its address removed. Incoming packets are not
delivered to an expired endpoint (unless in promiscuous mode), while sending
outgoing packets triggers an error to the caller (unless in spoofing mode).
In addition, a few helper functions were added to stack_test.go to reduce
code duplications.
PiperOrigin-RevId: 265514326
Diffstat (limited to 'pkg/tcpip/stack/nic.go')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 193 |
1 files changed, 138 insertions, 55 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4ef85bdfb..04b63d783 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -139,7 +139,7 @@ func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Add if list, ok := n.primary[protocol]; ok { for e := list.Front(); e != nil; e = e.Next() { ref := e.(*referencedNetworkEndpoint) - if ref.holdsInsertRef && ref.tryIncRef() { + if ref.kind == permanent && ref.tryIncRef() { r = ref break } @@ -178,7 +178,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN case header.IPv4Broadcast, header.IPv4Any: continue } - if r.tryIncRef() { + if r.isValidForOutgoing() && r.tryIncRef() { return r } } @@ -197,20 +197,32 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // getRefEpOrCreateTemp returns the referenced network endpoint for the given // protocol and address. If none exists a temporary one may be created if -// requested. -func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, allowTemp bool) *referencedNetworkEndpoint { +// we are in promiscuous mode or spoofing. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.RUnlock() - return ref + if ref, ok := n.endpoints[id]; ok { + // An endpoint with this id exists, check if it can be used and return it. + switch ref.kind { + case permanentExpired: + if !spoofingOrPromiscuous { + n.mu.RUnlock() + return nil + } + fallthrough + case temporary, permanent: + if ref.tryIncRef() { + n.mu.RUnlock() + return ref + } + } } - // The address was not found, create a temporary one if requested by the - // caller or if the address is found in the NIC's subnets. - createTempEP := allowTemp + // A usable reference was not found, create a temporary one if requested by + // the caller or if the address is found in the NIC's subnets. + createTempEP := spoofingOrPromiscuous if !createTempEP { for _, sn := range n.subnets { if sn.Contains(address) { @@ -230,34 +242,68 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.Unlock() - return ref + if ref, ok := n.endpoints[id]; ok { + // No need to check the type as we are ok with expired endpoints at this + // point. + if ref.tryIncRef() { + n.mu.Unlock() + return ref + } + // tryIncRef failing means the endpoint is scheduled to be removed once the + // lock is released. Remove it here so we can create a new (temporary) one. + // The removal logic waiting for the lock handles this case. + n.removeEndpointLocked(ref) } + // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { n.mu.Unlock() return nil } - ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, true) - - if ref != nil { - ref.holdsInsertRef = false - } + }, peb, temporary) n.mu.Unlock() return ref } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) { + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if ref, ok := n.endpoints[id]; ok { + switch ref.kind { + case permanent: + // The NIC already have a permanent endpoint with that address. + return nil, tcpip.ErrDuplicateAddress + case permanentExpired, temporary: + // Promote the endpoint to become permanent. + if ref.tryIncRef() { + ref.kind = permanent + return ref, nil + } + // tryIncRef failing means the endpoint is scheduled to be removed once + // the lock is released. Remove it here so we can create a new + // (permanent) one. The removal logic waiting for the lock handles this + // case. + n.removeEndpointLocked(ref) + } + } + return n.addAddressLocked(protocolAddress, peb, permanent) +} + +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { + // Sanity check. + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if _, ok := n.endpoints[id]; ok { + // Endpoint already exists. + return nil, tcpip.ErrDuplicateAddress + } + netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] if !ok { return nil, tcpip.ErrUnknownProtocol @@ -268,22 +314,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar if err != nil { return nil, err } - - id := *ep.ID() - if ref, ok := n.endpoints[id]; ok { - if !replace { - return nil, tcpip.ErrDuplicateAddress - } - - n.removeEndpointLocked(ref) - } - ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - holdsInsertRef: true, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, } // Set up cache if link address resolution exists for this protocol. @@ -316,7 +352,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addAddressLocked(protocolAddress, peb, false) + _, err := n.addPermanentAddressLocked(protocolAddress, peb) n.mu.Unlock() return err @@ -328,6 +364,12 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { defer n.mu.RUnlock() addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { + // Don't include expired or tempory endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.kind { + case permanentExpired, temporary: + continue + } addrs = append(addrs, tcpip.ProtocolAddress{ Protocol: ref.protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -393,13 +435,16 @@ func (n *NIC) Subnets() []tcpip.Subnet { func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { id := *r.ep.ID() - // Nothing to do if the reference has already been replaced with a - // different one. + // Nothing to do if the reference has already been replaced with a different + // one. This happens in the case where 1) this endpoint's ref count hit zero + // and was waiting (on the lock) to be removed and 2) the same address was + // re-added in the meantime by removing this endpoint from the list and + // adding a new one. if n.endpoints[id] != r { return } - if r.holdsInsertRef { + if r.kind == permanent { panic("Reference count dropped to zero before being removed") } @@ -418,14 +463,13 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { n.mu.Unlock() } -func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { +func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || !r.holdsInsertRef { + if r == nil || r.kind != permanent { return tcpip.ErrBadLocalAddress } - r.holdsInsertRef = false - + r.kind = permanentExpired r.decRefLocked() return nil @@ -435,7 +479,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.removeAddressLocked(addr) + return n.removePermanentAddressLocked(addr) } // joinGroup adds a new endpoint for the given multicast address, if none @@ -451,13 +495,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address if !ok { return tcpip.ErrUnknownProtocol } - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ + if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: addr, PrefixLen: netProto.DefaultPrefixLen(), }, - }, NeverPrimaryEndpoint, false); err != nil { + }, NeverPrimaryEndpoint); err != nil { return err } } @@ -479,7 +523,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress case 1: // This is the last one, clean up. - if err := n.removeAddressLocked(addr); err != nil { + if err := n.removePermanentAddressLocked(addr); err != nil { return err } } @@ -521,7 +565,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr // n.endpoints is mutex protected so acquire lock. n.mu.RLock() for _, ref := range n.endpoints { - if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { + if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remote ref.ep.HandlePacket(&r, vv) @@ -559,8 +603,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n := r.ref.nic n.mu.RLock() ref, ok := n.endpoints[NetworkEndpointID{dst}] + ok = ok && ref.isValidForOutgoing() && ref.tryIncRef() n.mu.RUnlock() - if ok && ref.tryIncRef() { + if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. ref.ep.HandlePacket(&r, vv) @@ -672,9 +717,33 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +type networkEndpointKind int + +const ( + // A permanent endpoint is created by adding a permanent address (vs. a + // temporary one) to the NIC. Its reference count is biased by 1 to avoid + // removal when no route holds a reference to it. It is removed by explicitly + // removing the permanent address from the NIC. + permanent networkEndpointKind = iota + + // An expired permanent endoint is a permanent endoint that had its address + // removed from the NIC, and it is waiting to be removed once no more routes + // hold a reference to it. This is achieved by decreasing its reference count + // by 1. If its address is re-added before the endpoint is removed, its type + // changes back to permanent and its reference count increases by 1 again. + permanentExpired + + // A temporary endpoint is created for spoofing outgoing packets, or when in + // promiscuous mode and accepting incoming packets that don't match any + // permanent endpoint. Its reference count is not biased by 1 and the + // endpoint is removed immediately when no more route holds a reference to + // it. A temporary endpoint can be promoted to permanent if its address + // is added permanently. + temporary +) + type referencedNetworkEndpoint struct { ilist.Entry - refs int32 ep NetworkEndpoint nic *NIC protocol tcpip.NetworkProtocolNumber @@ -683,11 +752,25 @@ type referencedNetworkEndpoint struct { // protocol. Set to nil otherwise. linkCache LinkAddressCache - // holdsInsertRef is protected by the NIC's mutex. It indicates whether - // the reference count is biased by 1 due to the insertion of the - // endpoint. It is reset to false when RemoveAddress is called on the - // NIC. - holdsInsertRef bool + // refs is counting references held for this endpoint. When refs hits zero it + // triggers the automatic removal of the endpoint from the NIC. + refs int32 + + kind networkEndpointKind +} + +// isValidForOutgoing returns true if the endpoint can be used to send out a +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in spoofing mode. +func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { + return r.kind != permanentExpired || r.nic.spoofing +} + +// isValidForIncoming returns true if the endpoint can accept an incoming +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in promiscuous mode. +func (r *referencedNetworkEndpoint) isValidForIncoming() bool { + return r.kind != permanentExpired || r.nic.promiscuous } // decRef decrements the ref count and cleans up the endpoint once it reaches |