diff options
Diffstat (limited to 'pkg/tcpip/stack/nic.go')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 280 |
1 files changed, 172 insertions, 108 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dc28dc970..04b63d783 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -139,7 +139,7 @@ func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Add if list, ok := n.primary[protocol]; ok { for e := list.Front(); e != nil; e = e.Next() { ref := e.(*referencedNetworkEndpoint) - if ref.holdsInsertRef && ref.tryIncRef() { + if ref.kind == permanent && ref.tryIncRef() { r = ref break } @@ -178,7 +178,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN case header.IPv4Broadcast, header.IPv4Any: continue } - if r.tryIncRef() { + if r.isValidForOutgoing() && r.tryIncRef() { return r } } @@ -186,46 +186,124 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return nil } +func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { + return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) +} + // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing) +} + +// getRefEpOrCreateTemp returns the referenced network endpoint for the given +// protocol and address. If none exists a temporary one may be created if +// we are in promiscuous mode or spoofing. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() - ref := n.endpoints[id] - if ref != nil && !ref.tryIncRef() { - ref = nil + + if ref, ok := n.endpoints[id]; ok { + // An endpoint with this id exists, check if it can be used and return it. + switch ref.kind { + case permanentExpired: + if !spoofingOrPromiscuous { + n.mu.RUnlock() + return nil + } + fallthrough + case temporary, permanent: + if ref.tryIncRef() { + n.mu.RUnlock() + return ref + } + } } - spoofing := n.spoofing + + // A usable reference was not found, create a temporary one if requested by + // the caller or if the address is found in the NIC's subnets. + createTempEP := spoofingOrPromiscuous + if !createTempEP { + for _, sn := range n.subnets { + if sn.Contains(address) { + createTempEP = true + break + } + } + } + n.mu.RUnlock() - if ref != nil || !spoofing { - return ref + if !createTempEP { + return nil } // Try again with the lock in exclusive mode. If we still can't get the // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - ref = n.endpoints[id] - if ref == nil || !ref.tryIncRef() { - if netProto, ok := n.stack.networkProtocols[protocol]; ok { - ref, _ = n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb, true) - if ref != nil { - ref.holdsInsertRef = false - } + if ref, ok := n.endpoints[id]; ok { + // No need to check the type as we are ok with expired endpoints at this + // point. + if ref.tryIncRef() { + n.mu.Unlock() + return ref } + // tryIncRef failing means the endpoint is scheduled to be removed once the + // lock is released. Remove it here so we can create a new (temporary) one. + // The removal logic waiting for the lock handles this case. + n.removeEndpointLocked(ref) } + + // Add a new temporary endpoint. + netProto, ok := n.stack.networkProtocols[protocol] + if !ok { + n.mu.Unlock() + return nil + } + ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: netProto.DefaultPrefixLen(), + }, + }, peb, temporary) + n.mu.Unlock() return ref } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) { + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if ref, ok := n.endpoints[id]; ok { + switch ref.kind { + case permanent: + // The NIC already have a permanent endpoint with that address. + return nil, tcpip.ErrDuplicateAddress + case permanentExpired, temporary: + // Promote the endpoint to become permanent. + if ref.tryIncRef() { + ref.kind = permanent + return ref, nil + } + // tryIncRef failing means the endpoint is scheduled to be removed once + // the lock is released. Remove it here so we can create a new + // (permanent) one. The removal logic waiting for the lock handles this + // case. + n.removeEndpointLocked(ref) + } + } + return n.addAddressLocked(protocolAddress, peb, permanent) +} + +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { + // Sanity check. + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if _, ok := n.endpoints[id]; ok { + // Endpoint already exists. + return nil, tcpip.ErrDuplicateAddress + } + netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] if !ok { return nil, tcpip.ErrUnknownProtocol @@ -236,22 +314,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar if err != nil { return nil, err } - - id := *ep.ID() - if ref, ok := n.endpoints[id]; ok { - if !replace { - return nil, tcpip.ErrDuplicateAddress - } - - n.removeEndpointLocked(ref) - } - ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - holdsInsertRef: true, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, } // Set up cache if link address resolution exists for this protocol. @@ -284,7 +352,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addAddressLocked(protocolAddress, peb, false) + _, err := n.addPermanentAddressLocked(protocolAddress, peb) n.mu.Unlock() return err @@ -296,6 +364,12 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { defer n.mu.RUnlock() addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { + // Don't include expired or tempory endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.kind { + case permanentExpired, temporary: + continue + } addrs = append(addrs, tcpip.ProtocolAddress{ Protocol: ref.protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -361,13 +435,16 @@ func (n *NIC) Subnets() []tcpip.Subnet { func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { id := *r.ep.ID() - // Nothing to do if the reference has already been replaced with a - // different one. + // Nothing to do if the reference has already been replaced with a different + // one. This happens in the case where 1) this endpoint's ref count hit zero + // and was waiting (on the lock) to be removed and 2) the same address was + // re-added in the meantime by removing this endpoint from the list and + // adding a new one. if n.endpoints[id] != r { return } - if r.holdsInsertRef { + if r.kind == permanent { panic("Reference count dropped to zero before being removed") } @@ -386,14 +463,13 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { n.mu.Unlock() } -func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { +func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || !r.holdsInsertRef { + if r == nil || r.kind != permanent { return tcpip.ErrBadLocalAddress } - r.holdsInsertRef = false - + r.kind = permanentExpired r.decRefLocked() return nil @@ -403,7 +479,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.removeAddressLocked(addr) + return n.removePermanentAddressLocked(addr) } // joinGroup adds a new endpoint for the given multicast address, if none @@ -419,13 +495,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address if !ok { return tcpip.ErrUnknownProtocol } - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ + if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: addr, PrefixLen: netProto.DefaultPrefixLen(), }, - }, NeverPrimaryEndpoint, false); err != nil { + }, NeverPrimaryEndpoint); err != nil { return err } } @@ -447,7 +523,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress case 1: // This is the last one, clean up. - if err := n.removeAddressLocked(addr); err != nil { + if err := n.removePermanentAddressLocked(addr); err != nil { return err } } @@ -489,7 +565,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr // n.endpoints is mutex protected so acquire lock. n.mu.RLock() for _, ref := range n.endpoints { - if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { + if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remote ref.ep.HandlePacket(&r, vv) @@ -527,8 +603,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n := r.ref.nic n.mu.RLock() ref, ok := n.endpoints[NetworkEndpointID{dst}] + ok = ok && ref.isValidForOutgoing() && ref.tryIncRef() n.mu.RUnlock() - if ok && ref.tryIncRef() { + if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. ref.ep.HandlePacket(&r, vv) @@ -553,57 +630,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n.stack.stats.IP.InvalidAddressesReceived.Increment() } -func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - id := NetworkEndpointID{dst} - - n.mu.RLock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.RUnlock() - return ref - } - - promiscuous := n.promiscuous - // Check if the packet is for a subnet this NIC cares about. - if !promiscuous { - for _, sn := range n.subnets { - if sn.Contains(dst) { - promiscuous = true - break - } - } - } - n.mu.RUnlock() - if promiscuous { - // Try again with the lock in exclusive mode. If we still can't - // get the endpoint, create a new "temporary" one. It will only - // exist while there's a route through it. - n.mu.Lock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.Unlock() - return ref - } - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - n.mu.Unlock() - return nil - } - ref, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: dst, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, CanBePrimaryEndpoint, true) - n.mu.Unlock() - if err == nil { - ref.holdsInsertRef = false - return ref - } - } - - return nil -} - // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { @@ -691,9 +717,33 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +type networkEndpointKind int + +const ( + // A permanent endpoint is created by adding a permanent address (vs. a + // temporary one) to the NIC. Its reference count is biased by 1 to avoid + // removal when no route holds a reference to it. It is removed by explicitly + // removing the permanent address from the NIC. + permanent networkEndpointKind = iota + + // An expired permanent endoint is a permanent endoint that had its address + // removed from the NIC, and it is waiting to be removed once no more routes + // hold a reference to it. This is achieved by decreasing its reference count + // by 1. If its address is re-added before the endpoint is removed, its type + // changes back to permanent and its reference count increases by 1 again. + permanentExpired + + // A temporary endpoint is created for spoofing outgoing packets, or when in + // promiscuous mode and accepting incoming packets that don't match any + // permanent endpoint. Its reference count is not biased by 1 and the + // endpoint is removed immediately when no more route holds a reference to + // it. A temporary endpoint can be promoted to permanent if its address + // is added permanently. + temporary +) + type referencedNetworkEndpoint struct { ilist.Entry - refs int32 ep NetworkEndpoint nic *NIC protocol tcpip.NetworkProtocolNumber @@ -702,11 +752,25 @@ type referencedNetworkEndpoint struct { // protocol. Set to nil otherwise. linkCache LinkAddressCache - // holdsInsertRef is protected by the NIC's mutex. It indicates whether - // the reference count is biased by 1 due to the insertion of the - // endpoint. It is reset to false when RemoveAddress is called on the - // NIC. - holdsInsertRef bool + // refs is counting references held for this endpoint. When refs hits zero it + // triggers the automatic removal of the endpoint from the NIC. + refs int32 + + kind networkEndpointKind +} + +// isValidForOutgoing returns true if the endpoint can be used to send out a +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in spoofing mode. +func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { + return r.kind != permanentExpired || r.nic.spoofing +} + +// isValidForIncoming returns true if the endpoint can accept an incoming +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in promiscuous mode. +func (r *referencedNetworkEndpoint) isValidForIncoming() bool { + return r.kind != permanentExpired || r.nic.promiscuous } // decRef decrements the ref count and cleans up the endpoint once it reaches |