diff options
Diffstat (limited to 'pkg/tcpip/stack/nic.go')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 460 |
1 files changed, 306 insertions, 154 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4ef85bdfb..5993fe582 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -34,15 +34,13 @@ type NIC struct { linkEP LinkEndpoint loopback bool - demux *transportDemuxer - - mu sync.RWMutex - spoofing bool - promiscuous bool - primary map[tcpip.NetworkProtocolNumber]*ilist.List - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - subnets []tcpip.Subnet - mcastJoins map[NetworkEndpointID]int32 + mu sync.RWMutex + spoofing bool + promiscuous bool + primary map[tcpip.NetworkProtocolNumber]*ilist.List + endpoints map[NetworkEndpointID]*referencedNetworkEndpoint + addressRanges []tcpip.Subnet + mcastJoins map[NetworkEndpointID]int32 stats NICStats } @@ -85,7 +83,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback name: name, linkEP: ep, loopback: loopback, - demux: newTransportDemuxer(stack), primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), @@ -102,6 +99,35 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback } } +// enable enables the NIC. enable will attach the link to its LinkEndpoint and +// join the IPv6 All-Nodes Multicast address (ff02::1). +func (n *NIC) enable() *tcpip.Error { + n.attachLinkEndpoint() + + // Create an endpoint to receive broadcast packets on this interface. + if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + if err := n.AddAddress(tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, + }, NeverPrimaryEndpoint); err != nil { + return err + } + } + + // Join the IPv6 All-Nodes Multicast group if the stack is configured to + // use IPv6. This is required to ensure that this node properly receives + // and responds to the various NDP messages that are destined to the + // all-nodes multicast address. An example is the Neighbor Advertisement + // when we perform Duplicate Address Detection, or Router Advertisement + // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 + // section 4.2 for more information. + if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress) + } + + return nil +} + // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { @@ -129,37 +155,6 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { - n.mu.RLock() - defer n.mu.RUnlock() - - var r *referencedNetworkEndpoint - - // Check for a primary endpoint. - if list, ok := n.primary[protocol]; ok { - for e := list.Front(); e != nil; e = e.Next() { - ref := e.(*referencedNetworkEndpoint) - if ref.holdsInsertRef && ref.tryIncRef() { - r = ref - break - } - } - - } - - if r == nil { - return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress - } - - addressWithPrefix := tcpip.AddressWithPrefix{ - Address: r.ep.ID().LocalAddress, - PrefixLen: r.ep.PrefixLen(), - } - r.decRef() - - return addressWithPrefix, nil -} - // primaryEndpoint returns the primary endpoint of n for the given network // protocol. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { @@ -178,7 +173,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN case header.IPv4Broadcast, header.IPv4Any: continue } - if r.tryIncRef() { + if r.isValidForOutgoing() && r.tryIncRef() { return r } } @@ -197,22 +192,44 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // getRefEpOrCreateTemp returns the referenced network endpoint for the given // protocol and address. If none exists a temporary one may be created if -// requested. -func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, allowTemp bool) *referencedNetworkEndpoint { +// we are in promiscuous mode or spoofing. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.RUnlock() - return ref + if ref, ok := n.endpoints[id]; ok { + // An endpoint with this id exists, check if it can be used and return it. + switch ref.getKind() { + case permanentExpired: + if !spoofingOrPromiscuous { + n.mu.RUnlock() + return nil + } + fallthrough + case temporary, permanent: + if ref.tryIncRef() { + n.mu.RUnlock() + return ref + } + } } - // The address was not found, create a temporary one if requested by the - // caller or if the address is found in the NIC's subnets. - createTempEP := allowTemp + // A usable reference was not found, create a temporary one if requested by + // the caller or if the address is found in the NIC's subnets. + createTempEP := spoofingOrPromiscuous if !createTempEP { - for _, sn := range n.subnets { + for _, sn := range n.addressRanges { + // Skip the subnet address. + if address == sn.ID() { + continue + } + // For now just skip the broadcast address, until we support it. + // FIXME(b/137608825): Add support for sending/receiving directed + // (subnet) broadcast. + if address == sn.Broadcast() { + continue + } if sn.Contains(address) { createTempEP = true break @@ -230,34 +247,70 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.Unlock() - return ref + if ref, ok := n.endpoints[id]; ok { + // No need to check the type as we are ok with expired endpoints at this + // point. + if ref.tryIncRef() { + n.mu.Unlock() + return ref + } + // tryIncRef failing means the endpoint is scheduled to be removed once the + // lock is released. Remove it here so we can create a new (temporary) one. + // The removal logic waiting for the lock handles this case. + n.removeEndpointLocked(ref) } + // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { n.mu.Unlock() return nil } - ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, true) - - if ref != nil { - ref.holdsInsertRef = false - } + }, peb, temporary) n.mu.Unlock() return ref } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) { + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if ref, ok := n.endpoints[id]; ok { + switch ref.getKind() { + case permanent: + // The NIC already have a permanent endpoint with that address. + return nil, tcpip.ErrDuplicateAddress + case permanentExpired, temporary: + // Promote the endpoint to become permanent. + if ref.tryIncRef() { + ref.setKind(permanent) + return ref, nil + } + // tryIncRef failing means the endpoint is scheduled to be removed once + // the lock is released. Remove it here so we can create a new + // (permanent) one. The removal logic waiting for the lock handles this + // case. + n.removeEndpointLocked(ref) + } + } + return n.addAddressLocked(protocolAddress, peb, permanent) +} + +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { + // TODO(b/141022673): Validate IP address before adding them. + + // Sanity check. + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if _, ok := n.endpoints[id]; ok { + // Endpoint already exists. + return nil, tcpip.ErrDuplicateAddress + } + netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] if !ok { return nil, tcpip.ErrUnknownProtocol @@ -268,22 +321,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar if err != nil { return nil, err } - - id := *ep.ID() - if ref, ok := n.endpoints[id]; ok { - if !replace { - return nil, tcpip.ErrDuplicateAddress - } - - n.removeEndpointLocked(ref) - } - ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - holdsInsertRef: true, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, } // Set up cache if link address resolution exists for this protocol. @@ -293,6 +336,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } + // If we are adding an IPv6 unicast address, join the solicited-node + // multicast address. + if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) { + snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) + if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { + return nil, err + } + } + n.endpoints[id] = ref l, ok := n.primary[protocolAddress.Protocol] @@ -316,18 +368,26 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addAddressLocked(protocolAddress, peb, false) + _, err := n.addPermanentAddressLocked(protocolAddress, peb) n.mu.Unlock() return err } -// Addresses returns the addresses associated with this NIC. -func (n *NIC) Addresses() []tcpip.ProtocolAddress { +// AllAddresses returns all addresses (primary and non-primary) associated with +// this NIC. +func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { n.mu.RLock() defer n.mu.RUnlock() + addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { + // Don't include expired or temporary endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.getKind() { + case permanentExpired, temporary: + continue + } addrs = append(addrs, tcpip.ProtocolAddress{ Protocol: ref.protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -339,45 +399,66 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { return addrs } -// AddSubnet adds a new subnet to n, so that it starts accepting packets -// targeted at the given address and network protocol. -func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { +// PrimaryAddresses returns the primary addresses associated with this NIC. +func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { + n.mu.RLock() + defer n.mu.RUnlock() + + var addrs []tcpip.ProtocolAddress + for proto, list := range n.primary { + for e := list.Front(); e != nil; e = e.Next() { + ref := e.(*referencedNetworkEndpoint) + // Don't include expired or tempory endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.getKind() { + case permanentExpired, temporary: + continue + } + + addrs = append(addrs, tcpip.ProtocolAddress{ + Protocol: proto, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: ref.ep.ID().LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + }, + }) + } + } + return addrs +} + +// AddAddressRange adds a range of addresses to n, so that it starts accepting +// packets targeted at the given addresses and network protocol. The range is +// given by a subnet address, and all addresses contained in the subnet are +// used except for the subnet address itself and the subnet's broadcast +// address. +func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { n.mu.Lock() - n.subnets = append(n.subnets, subnet) + n.addressRanges = append(n.addressRanges, subnet) n.mu.Unlock() } -// RemoveSubnet removes the given subnet from n. -func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) { +// RemoveAddressRange removes the given address range from n. +func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Lock() // Use the same underlying array. - tmp := n.subnets[:0] - for _, sub := range n.subnets { + tmp := n.addressRanges[:0] + for _, sub := range n.addressRanges { if sub != subnet { tmp = append(tmp, sub) } } - n.subnets = tmp + n.addressRanges = tmp n.mu.Unlock() } -// ContainsSubnet reports whether this NIC contains the given subnet. -func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool { - for _, s := range n.Subnets() { - if s == subnet { - return true - } - } - return false -} - // Subnets returns the Subnets associated with this NIC. -func (n *NIC) Subnets() []tcpip.Subnet { +func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints)) + sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints)) for nid := range n.endpoints { sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) if err != nil { @@ -387,19 +468,22 @@ func (n *NIC) Subnets() []tcpip.Subnet { } sns = append(sns, sn) } - return append(sns, n.subnets...) + return append(sns, n.addressRanges...) } func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { id := *r.ep.ID() - // Nothing to do if the reference has already been replaced with a - // different one. + // Nothing to do if the reference has already been replaced with a different + // one. This happens in the case where 1) this endpoint's ref count hit zero + // and was waiting (on the lock) to be removed and 2) the same address was + // re-added in the meantime by removing this endpoint from the list and + // adding a new one. if n.endpoints[id] != r { return } - if r.holdsInsertRef { + if r.getKind() == permanent { panic("Reference count dropped to zero before being removed") } @@ -418,15 +502,28 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { n.mu.Unlock() } -func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { - r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || !r.holdsInsertRef { +func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { + r, ok := n.endpoints[NetworkEndpointID{addr}] + if !ok || r.getKind() != permanent { return tcpip.ErrBadLocalAddress } - r.holdsInsertRef = false + r.setKind(permanentExpired) + if !r.decRefLocked() { + // The endpoint still has references to it. + return nil + } + + // At this point the endpoint is deleted. - r.decRefLocked() + // If we are removing an IPv6 unicast address, leave the solicited-node + // multicast address. + if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) { + snmc := header.SolicitedNodeAddr(addr) + if err := n.leaveGroupLocked(snmc); err != nil { + return err + } + } return nil } @@ -435,7 +532,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.removeAddressLocked(addr) + return n.removePermanentAddressLocked(addr) } // joinGroup adds a new endpoint for the given multicast address, if none @@ -444,6 +541,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address n.mu.Lock() defer n.mu.Unlock() + return n.joinGroupLocked(protocol, addr) +} + +// joinGroupLocked adds a new endpoint for the given multicast address, if none +// exists yet. Otherwise it just increments its count. n MUST be locked before +// joinGroupLocked is called. +func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] if joins == 0 { @@ -451,13 +555,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address if !ok { return tcpip.ErrUnknownProtocol } - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ + if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: addr, PrefixLen: netProto.DefaultPrefixLen(), }, - }, NeverPrimaryEndpoint, false); err != nil { + }, NeverPrimaryEndpoint); err != nil { return err } } @@ -471,6 +575,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() + return n.leaveGroupLocked(addr) +} + +// leaveGroupLocked decrements the count for the given multicast address, and +// when it reaches zero removes the endpoint for this address. n MUST be locked +// before leaveGroupLocked is called. +func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] switch joins { @@ -479,7 +590,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress case 1: // This is the last one, clean up. - if err := n.removeAddressLocked(addr); err != nil { + if err := n.removePermanentAddressLocked(addr); err != nil { return err } } @@ -487,6 +598,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return nil } +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) { + r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) + r.RemoteLinkAddress = remotelinkAddr + ref.ep.HandlePacket(&r, vv) + ref.decRef() +} + // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the physical interface. @@ -514,29 +632,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr src, dst := netProto.ParseAddresses(vv.First()) - // If the packet is destined to the IPv4 Broadcast address, then make a - // route to each IPv4 network endpoint and let each endpoint handle the - // packet. - if dst == header.IPv4Broadcast { - // n.endpoints is mutex protected so acquire lock. - n.mu.RLock() - for _, ref := range n.endpoints { - if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() - } - } - n.mu.RUnlock() - return - } + n.stack.AddLinkAddress(n.id, src, remote) if ref := n.getRef(protocol, dst); ref != nil { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) return } @@ -559,8 +658,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n := r.ref.nic n.mu.RLock() ref, ok := n.endpoints[NetworkEndpointID{dst}] + ok = ok && ref.isValidForOutgoing() && ref.tryIncRef() n.mu.RUnlock() - if ok && ref.tryIncRef() { + if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. ref.ep.HandlePacket(&r, vv) @@ -599,9 +699,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) { - n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) - } + n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) if len(vv.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() @@ -615,9 +713,6 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.demux.deliverPacket(r, protocol, netHeader, vv, id) { - return - } if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { return } @@ -631,7 +726,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { n.stack.stats.MalformedRcvdPackets.Increment() } } @@ -659,10 +754,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { - return - } - if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { + if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) { return } } @@ -672,9 +764,38 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +// Stack returns the instance of the Stack that owns this NIC. +func (n *NIC) Stack() *Stack { + return n.stack +} + +type networkEndpointKind int32 + +const ( + // A permanent endpoint is created by adding a permanent address (vs. a + // temporary one) to the NIC. Its reference count is biased by 1 to avoid + // removal when no route holds a reference to it. It is removed by explicitly + // removing the permanent address from the NIC. + permanent networkEndpointKind = iota + + // An expired permanent endoint is a permanent endoint that had its address + // removed from the NIC, and it is waiting to be removed once no more routes + // hold a reference to it. This is achieved by decreasing its reference count + // by 1. If its address is re-added before the endpoint is removed, its type + // changes back to permanent and its reference count increases by 1 again. + permanentExpired + + // A temporary endpoint is created for spoofing outgoing packets, or when in + // promiscuous mode and accepting incoming packets that don't match any + // permanent endpoint. Its reference count is not biased by 1 and the + // endpoint is removed immediately when no more route holds a reference to + // it. A temporary endpoint can be promoted to permanent if its address + // is added permanently. + temporary +) + type referencedNetworkEndpoint struct { ilist.Entry - refs int32 ep NetworkEndpoint nic *NIC protocol tcpip.NetworkProtocolNumber @@ -683,11 +804,34 @@ type referencedNetworkEndpoint struct { // protocol. Set to nil otherwise. linkCache LinkAddressCache - // holdsInsertRef is protected by the NIC's mutex. It indicates whether - // the reference count is biased by 1 due to the insertion of the - // endpoint. It is reset to false when RemoveAddress is called on the - // NIC. - holdsInsertRef bool + // refs is counting references held for this endpoint. When refs hits zero it + // triggers the automatic removal of the endpoint from the NIC. + refs int32 + + // networkEndpointKind must only be accessed using {get,set}Kind(). + kind networkEndpointKind +} + +func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { + return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind))) +} + +func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { + atomic.StoreInt32((*int32)(&r.kind), int32(kind)) +} + +// isValidForOutgoing returns true if the endpoint can be used to send out a +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in spoofing mode. +func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { + return r.getKind() != permanentExpired || r.nic.spoofing +} + +// isValidForIncoming returns true if the endpoint can accept an incoming +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in promiscuous mode. +func (r *referencedNetworkEndpoint) isValidForIncoming() bool { + return r.getKind() != permanentExpired || r.nic.promiscuous } // decRef decrements the ref count and cleans up the endpoint once it reaches @@ -699,11 +843,14 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. -func (r *referencedNetworkEndpoint) decRefLocked() { +// locked. Returns true if the endpoint was removed. +func (r *referencedNetworkEndpoint) decRefLocked() bool { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) + return true } + + return false } // incRef increments the ref count. It must only be called when the caller is @@ -728,3 +875,8 @@ func (r *referencedNetworkEndpoint) tryIncRef() bool { } } } + +// stack returns the Stack instance that owns the underlying endpoint. +func (r *referencedNetworkEndpoint) stack() *Stack { + return r.nic.stack +} |