diff options
Diffstat (limited to 'pkg/tcpip/stack/nic.go')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 297 |
1 files changed, 160 insertions, 137 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dcd4319bf..5d037a27e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -54,18 +54,20 @@ type NIC struct { sync.RWMutex spoofing bool promiscuous bool - // packetEPs is protected by mu, but the contained PacketEndpoint - // values are not. - packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint + // packetEPs is protected by mu, but the contained packetEndpointList are + // not. + packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList } } -// NICStats includes transmitted and received stats. +// NICStats hold statistics for a NIC. type NICStats struct { Tx DirectionStats Rx DirectionStats DisabledRx DirectionStats + + Neighbor NeighborStats } func makeNICStats() NICStats { @@ -80,6 +82,39 @@ type DirectionStats struct { Bytes *tcpip.StatCounter } +type packetEndpointList struct { + mu sync.RWMutex + + // eps is protected by mu, but the contained PacketEndpoint values are not. + eps []PacketEndpoint +} + +func (p *packetEndpointList) add(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + p.eps = append(p.eps, ep) +} + +func (p *packetEndpointList) remove(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + for i, epOther := range p.eps { + if epOther == ep { + p.eps = append(p.eps[:i], p.eps[i+1:]...) + break + } + } +} + +// forEach calls fn with each endpoints in p while holding the read lock on p. +func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) { + p.mu.RLock() + defer p.mu.RUnlock() + for _, ep := range p.eps { + fn(ep) + } +} + // newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For @@ -100,7 +135,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) + nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) // Check for Neighbor Unreachability Detection support. var nud NUDHandler @@ -123,11 +158,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { - nic.mu.packetEPs[netProto] = []PacketEndpoint{} + nic.mu.packetEPs[netProto] = new(packetEndpointList) } for _, netProto := range stack.networkProtocols { netNum := netProto.Number() - nic.mu.packetEPs[netNum] = nil + nic.mu.packetEPs[netNum] = new(packetEndpointList) nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } @@ -170,7 +205,7 @@ func (n *NIC) disable() { // // n MUST be locked. func (n *NIC) disableLocked() { - if !n.setEnabled(false) { + if !n.Enabled() { return } @@ -182,6 +217,10 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() } + + if !n.setEnabled(false) { + panic("should have only done work to disable the NIC if it was enabled") + } } // enable enables n. @@ -232,7 +271,8 @@ func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Unlock() } -func (n *NIC) isPromiscuousMode() bool { +// Promiscuous implements NetworkInterface. +func (n *NIC) Promiscuous() bool { n.mu.RLock() rv := n.mu.promiscuous n.mu.RUnlock() @@ -264,7 +304,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { r := r.Clone() - n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt) + n.stack.linkResQueue.enqueue(ch, r, protocol, pkt) return nil } return err @@ -273,6 +313,15 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb return n.writePacket(r, gso, protocol, pkt) } +// WritePacketToRemote implements NetworkInterface. +func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + r := Route{ + NetProto: protocol, + } + r.ResolveWith(remoteLinkAddr) + return n.writePacket(&r, gso, protocol, pkt) +} + func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() @@ -311,16 +360,21 @@ func (n *NIC) setSpoofing(enable bool) { // primaryAddress returns an address that can be used to communicate with // remoteAddr. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { - n.mu.RLock() - spoofing := n.mu.spoofing - n.mu.RUnlock() - ep, ok := n.networkEndpoints[protocol] if !ok { return nil } - return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return nil + } + + n.mu.RLock() + spoofing := n.mu.spoofing + n.mu.RUnlock() + + return addressableEndpoint.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) } type getAddressBehaviour int @@ -339,6 +393,16 @@ func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } +func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint) + if ep != nil { + ep.DecRef() + return true + } + + return false +} + // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { return n.getAddressOrCreateTemp(protocol, address, peb, spoofing) @@ -369,11 +433,17 @@ func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, addre // getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean // is passed to indicate whether or not we should generate temporary endpoints. func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { - if ep, ok := n.networkEndpoints[protocol]; ok { - return ep.AcquireAssignedAddress(address, createTemp, peb) + ep, ok := n.networkEndpoints[protocol] + if !ok { + return nil } - return nil + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return nil + } + + return addressableEndpoint.AcquireAssignedAddress(address, createTemp, peb) } // addAddress adds a new address to n, so that it starts accepting packets @@ -384,7 +454,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return tcpip.ErrUnknownProtocol } - addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return tcpip.ErrNotSupported + } + + addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) if err == nil { // We have no need for the address endpoint. addressEndpoint.DecRef() @@ -397,7 +472,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { - for _, a := range ep.PermanentAddresses() { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + for _, a := range addressableEndpoint.PermanentAddresses() { addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } @@ -408,7 +488,12 @@ func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { - for _, a := range ep.PrimaryAddresses() { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + for _, a := range addressableEndpoint.PrimaryAddresses() { addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } @@ -426,13 +511,23 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit return tcpip.AddressWithPrefix{} } - return ep.MainAddress() + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return tcpip.AddressWithPrefix{} + } + + return addressableEndpoint.MainAddress() } // removeAddress removes an address from n. func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { for _, ep := range n.networkEndpoints { - if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { continue } else { return err @@ -505,8 +600,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address return tcpip.ErrNotSupported } - _, err := gep.JoinGroup(addr) - return err + return gep.JoinGroup(addr) } // leaveGroup decrements the count for the given multicast address, and when it @@ -522,11 +616,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres return tcpip.ErrNotSupported } - if _, err := gep.LeaveGroup(addr); err != nil { - return err - } - - return nil + return gep.LeaveGroup(addr) } // isInGroup returns true if n has joined the multicast group addr. @@ -545,13 +635,6 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return false } -func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) - defer r.Release() - r.RemoteLinkAddress = remotelinkAddr - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) -} - // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the link endpoint. @@ -573,7 +656,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.stats.Rx.Packets.Increment() n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) - netProto, ok := n.stack.networkProtocols[protocol] + networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -585,23 +668,29 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if local == "" { local = n.LinkEndpoint.LinkAddress() } + pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 // Are any packet type sockets listening for this network protocol? - packetEPs := n.mu.packetEPs[protocol] - // Add any other packet type sockets that may be listening for all protocols. - packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) + protoEPs := n.mu.packetEPs[protocol] + // Other packet type sockets that are listening for all protocols. + anyEPs := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + // Deliver to interested packet endpoints without holding NIC lock. + deliverPacketEPs := func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketHost ep.HandlePacket(n.id, local, protocol, p) } - - if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { - n.stack.stats.IP.PacketsReceived.Increment() + if protoEPs != nil { + protoEPs.forEach(deliverPacketEPs) + } + if anyEPs != nil { + anyEPs.forEach(deliverPacketEPs) } // Parse headers. + netProto := n.stack.NetworkProtocolInstance(protocol) transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) if !ok { // The packet is too small to contain a network header. @@ -616,9 +705,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if n.stack.handleLocal && !n.IsLoopback() { + src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View()) if r := n.getAddress(protocol, src); r != nil { r.DecRef() @@ -631,78 +719,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - // Loopback traffic skips the prerouting chain. - if !n.IsLoopback() { - // iptables filtering. - ipt := n.stack.IPTables() - address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { - // iptables is telling us to drop the packet. - n.stack.stats.IP.IPTablesPreroutingDropped.Increment() - return - } - } - - if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil { - n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt) - return - } - - // This NIC doesn't care about the packet. Find a NIC that cares about the - // packet and forward it to the NIC. - // - // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding(protocol) { - r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) - if err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - return - } - - // Found a NIC. - n := r.nic - if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { - if n.isValidForOutgoing(addressEndpoint) { - r.LocalLinkAddress = n.LinkEndpoint.LinkAddress() - r.RemoteLinkAddress = remote - r.RemoteAddress = src - // TODO(b/123449044): Update the source NIC as well. - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) - addressEndpoint.DecRef() - r.Release() - return - } - - addressEndpoint.DecRef() - } - - // n doesn't have a destination endpoint. - // Send the packet out of n. - // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. - - // pkt may have set its header and may not have enough headroom for - // link-layer header for the other link to prepend. Here we create a new - // packet to forward. - fwdPkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), - // We need to do a deep copy of the IP packet because WritePacket (and - // friends) take ownership of the packet buffer, but we do not own it. - Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), - }) - - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } - - r.Release() - return - } - - // If a packet socket handled the packet, don't treat it as invalid. - if len(packetEPs) == 0 { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } + networkEndpoint.HandlePacket(pkt) } // DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. @@ -711,21 +728,22 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // We do not deliver to protocol specific packet endpoints as on Linux // only ETH_P_ALL endpoints get outbound packets. // Add any other packet sockets that maybe listening for all protocols. - packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + eps := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + eps.forEach(func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketOutgoing // Add the link layer header as outgoing packets are intercepted // before the link layer header is created. n.LinkEndpoint.AddHeader(local, remote, protocol, p) ep.HandlePacket(n.id, local, protocol, p) - } + }) } // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { +func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -737,7 +755,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - n.stack.demux.deliverRawPacket(r, protocol, pkt) + n.stack.demux.deliverRawPacket(protocol, pkt) // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. @@ -766,14 +784,25 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN return TransportPacketHandled } - id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.stack.demux.deliverPacket(r, protocol, pkt, id) { + netProto, ok := n.stack.networkProtocols[pkt.NetworkProtocolNumber] + if !ok { + panic(fmt.Sprintf("expected network protocol = %d, have = %#v", pkt.NetworkProtocolNumber, n.stack.networkProtocolNumbers())) + } + + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) + id := TransportEndpointID{ + LocalPort: dstPort, + LocalAddress: dst, + RemotePort: srcPort, + RemoteAddress: src, + } + if n.stack.demux.deliverPacket(protocol, pkt, id) { return TransportPacketHandled } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, pkt) { + if state.defaultHandler(id, pkt) { return TransportPacketHandled } } @@ -781,7 +810,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet so // give the protocol specific error handler a chance to handle it. // If it doesn't handle it then we should do so. - switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res { + switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: n.stack.stats.MalformedRcvdPackets.Increment() return TransportPacketHandled @@ -862,7 +891,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa if !ok { return tcpip.ErrNotSupported } - n.mu.packetEPs[netProto] = append(eps, ep) + eps.add(ep) return nil } @@ -875,17 +904,11 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep if !ok { return } - - for i, epOther := range eps { - if epOther == ep { - n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) - return - } - } + eps.remove(ep) } // isValidForOutgoing returns true if the endpoint can be used to send out a -// packet. It requires the endpoint to not be marked expired (i.e., its address) +// packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed) unless the NIC is in spoofing mode, or temporary. func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { n.mu.RLock() |