diff options
Diffstat (limited to 'pkg/tcpip/stack/nic.go')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 82 |
1 files changed, 51 insertions, 31 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ddc1ddab6..bb56ca9cb 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -41,6 +41,17 @@ func (l *linkResolver) confirmReachable(addr tcpip.Address) { var _ NetworkInterface = (*nic)(nil) +// TODO(https://gvisor.dev/issue/6558): Use an anonymous struct in nic for this +// once copylocks supports anonymous structs. +type packetEPs struct { + mu sync.RWMutex + + // eps is protected by the mutex, but the values contained in it are not. + // + // +checklocks:mu + eps map[tcpip.NetworkProtocolNumber]*packetEndpointList +} + // nic represents a "network interface card" to which the networking stack is // attached. type nic struct { @@ -72,10 +83,9 @@ type nic struct { sync.RWMutex spoofing bool promiscuous bool - // packetEPs is protected by mu, but the contained packetEndpointList are - // not. - packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList } + + packetEPs packetEPs } // makeNICStats initializes the NIC statistics and associates them to the global @@ -143,17 +153,21 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), } nic.linkResQueue.init(nic) - nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) + + nic.packetEPs.mu.Lock() + defer nic.packetEPs.mu.Unlock() + + nic.packetEPs.eps = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0 // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { - nic.mu.packetEPs[netProto] = new(packetEndpointList) + nic.packetEPs.eps[netProto] = new(packetEndpointList) } for _, netProto := range stack.networkProtocols { netNum := netProto.Number() - nic.mu.packetEPs[netNum] = new(packetEndpointList) + nic.packetEPs.eps[netNum] = new(packetEndpointList) netEP := netProto.NewEndpoint(nic, nic) nic.networkEndpoints[netNum] = netEP @@ -365,6 +379,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt pkt.EgressRoute = r pkt.NetworkProtocolNumber = protocol + n.deliverOutboundPacket(r.RemoteLinkAddress, pkt) + if err := n.LinkEndpoint.WritePacket(r, protocol, pkt); err != nil { return err } @@ -383,6 +399,7 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { pkt.EgressRoute = r pkt.NetworkProtocolNumber = protocol + n.deliverOutboundPacket(r.RemoteLinkAddress, pkt) } writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) @@ -699,12 +716,9 @@ func (n *nic) isInGroup(addr tcpip.Address) bool { // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - n.mu.RLock() enabled := n.Enabled() // If the NIC is not yet enabled, don't receive any packets. if !enabled { - n.mu.RUnlock() - n.stats.disabledRx.packets.Increment() n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size())) return @@ -715,7 +729,6 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { - n.mu.RUnlock() n.stats.unknownL3ProtocolRcvdPackets.Increment() return } @@ -727,12 +740,6 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 - // Are any packet type sockets listening for this network protocol? - protoEPs := n.mu.packetEPs[protocol] - // Other packet type sockets that are listening for all protocols. - anyEPs := n.mu.packetEPs[header.EthernetProtocolAll] - n.mu.RUnlock() - // Deliver to interested packet endpoints without holding NIC lock. var packetEPPkt *PacketBuffer deliverPacketEPs := func(ep PacketEndpoint) { @@ -758,24 +765,37 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone()) } - if protoEPs != nil { + + n.packetEPs.mu.Lock() + // Are any packet type sockets listening for this network protocol? + protoEPs, protoEPsOK := n.packetEPs.eps[protocol] + // Other packet type sockets that are listening for all protocols. + anyEPs, anyEPsOK := n.packetEPs.eps[header.EthernetProtocolAll] + n.packetEPs.mu.Unlock() + + if protoEPsOK { protoEPs.forEach(deliverPacketEPs) } - if anyEPs != nil { + if anyEPsOK { anyEPs.forEach(deliverPacketEPs) } networkEndpoint.HandlePacket(pkt) } -// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. -func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - n.mu.RLock() +// deliverOutboundPacket delivers outgoing packets to interested endpoints. +func (n *nic) deliverOutboundPacket(remote tcpip.LinkAddress, pkt *PacketBuffer) { + n.packetEPs.mu.RLock() + defer n.packetEPs.mu.RUnlock() // We do not deliver to protocol specific packet endpoints as on Linux // only ETH_P_ALL endpoints get outbound packets. // Add any other packet sockets that maybe listening for all protocols. - eps := n.mu.packetEPs[header.EthernetProtocolAll] - n.mu.RUnlock() + eps, ok := n.packetEPs.eps[header.EthernetProtocolAll] + if !ok { + return + } + + local := n.LinkAddress() var packetEPPkt *PacketBuffer eps.forEach(func(ep PacketEndpoint) { @@ -796,11 +816,11 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // Add the link layer header as outgoing packets are intercepted before // the link layer header is created and packet endpoints are interested // in the link header. - n.LinkEndpoint.AddHeader(local, remote, protocol, packetEPPkt) + n.LinkEndpoint.AddHeader(local, remote, pkt.NetworkProtocolNumber, packetEPPkt) packetEPPkt.PktType = tcpip.PacketOutgoing } - ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone()) + ep.HandlePacket(n.id, local, pkt.NetworkProtocolNumber, packetEPPkt.Clone()) }) } @@ -953,10 +973,10 @@ func (n *nic) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigura } func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() + n.packetEPs.mu.Lock() + defer n.packetEPs.mu.Unlock() - eps, ok := n.mu.packetEPs[netProto] + eps, ok := n.packetEPs.eps[netProto] if !ok { return &tcpip.ErrNotSupported{} } @@ -966,10 +986,10 @@ func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa } func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { - n.mu.Lock() - defer n.mu.Unlock() + n.packetEPs.mu.Lock() + defer n.packetEPs.mu.Unlock() - eps, ok := n.mu.packetEPs[netProto] + eps, ok := n.packetEPs.eps[netProto] if !ok { return } |