summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/nic.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/nic.go')
-rw-r--r--pkg/tcpip/stack/nic.go148
1 files changed, 103 insertions, 45 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index b854d868c..29d580e76 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -72,9 +72,15 @@ type nic struct {
sync.RWMutex
spoofing bool
promiscuous bool
- // packetEPs is protected by mu, but the contained packetEndpointList are
- // not.
- packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList
+ }
+
+ packetEPs struct {
+ mu sync.RWMutex
+
+ // eps is protected by the mutex, but the values contained in it are not.
+ //
+ // +checklocks:mu
+ eps map[tcpip.NetworkProtocolNumber]*packetEndpointList
}
}
@@ -91,6 +97,8 @@ type packetEndpointList struct {
mu sync.RWMutex
// eps is protected by mu, but the contained PacketEndpoint values are not.
+ //
+ // +checklocks:mu
eps []PacketEndpoint
}
@@ -111,6 +119,12 @@ func (p *packetEndpointList) remove(ep PacketEndpoint) {
}
}
+func (p *packetEndpointList) len() int {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ return len(p.eps)
+}
+
// forEach calls fn with each endpoints in p while holding the read lock on p.
func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
p.mu.RLock()
@@ -143,18 +157,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector),
}
nic.linkResQueue.init(nic)
- nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
+
+ nic.packetEPs.mu.Lock()
+ defer nic.packetEPs.mu.Unlock()
+
+ nic.packetEPs.eps = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0
- // Register supported packet and network endpoint protocols.
- for _, netProto := range header.Ethertypes {
- nic.mu.packetEPs[netProto] = new(packetEndpointList)
- }
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
- nic.mu.packetEPs[netNum] = new(packetEndpointList)
-
netEP := netProto.NewEndpoint(nic, nic)
nic.networkEndpoints[netNum] = netEP
@@ -365,6 +377,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt
pkt.EgressRoute = r
pkt.NetworkProtocolNumber = protocol
+ n.deliverOutboundPacket(r.RemoteLinkAddress, pkt)
+
if err := n.LinkEndpoint.WritePacket(r, protocol, pkt); err != nil {
return err
}
@@ -383,6 +397,7 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
pkt.EgressRoute = r
pkt.NetworkProtocolNumber = protocol
+ n.deliverOutboundPacket(r.RemoteLinkAddress, pkt)
}
writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol)
@@ -501,7 +516,7 @@ func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber,
// addAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error {
ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
return &tcpip.ErrUnknownProtocol{}
@@ -512,7 +527,7 @@ func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
return &tcpip.ErrNotSupported{}
}
- addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, properties)
if err == nil {
// We have no need for the address endpoint.
addressEndpoint.DecRef()
@@ -699,12 +714,9 @@ func (n *nic) isInGroup(addr tcpip.Address) bool {
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- n.mu.RLock()
enabled := n.Enabled()
// If the NIC is not yet enabled, don't receive any packets.
if !enabled {
- n.mu.RUnlock()
-
n.stats.disabledRx.packets.Increment()
n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size()))
return
@@ -715,7 +727,6 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
networkEndpoint, ok := n.networkEndpoints[protocol]
if !ok {
- n.mu.RUnlock()
n.stats.unknownL3ProtocolRcvdPackets.Increment()
return
}
@@ -727,44 +738,87 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
- // Are any packet type sockets listening for this network protocol?
- protoEPs := n.mu.packetEPs[protocol]
- // Other packet type sockets that are listening for all protocols.
- anyEPs := n.mu.packetEPs[header.EthernetProtocolAll]
- n.mu.RUnlock()
-
// Deliver to interested packet endpoints without holding NIC lock.
+ var packetEPPkt *PacketBuffer
deliverPacketEPs := func(ep PacketEndpoint) {
- p := pkt.Clone()
- p.PktType = tcpip.PacketHost
- ep.HandlePacket(n.id, local, protocol, p)
+ if packetEPPkt == nil {
+ // Packet endpoints hold the full packet.
+ //
+ // We perform a deep copy because higher-level endpoints may point to
+ // the middle of a view that is held by a packet endpoint. Save/Restore
+ // does not support overlapping slices and will panic in this case.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports
+ // overlapping slices (e.g. by passing a shallow copy of pkt to the packet
+ // endpoint).
+ packetEPPkt = NewPacketBuffer(PacketBufferOptions{
+ Data: PayloadSince(pkt.LinkHeader()).ToVectorisedView(),
+ })
+ // If a link header was populated in the original packet buffer, then
+ // populate it in the packet buffer we provide to packet endpoints as
+ // packet endpoints inspect link headers.
+ packetEPPkt.LinkHeader().Consume(pkt.LinkHeader().View().Size())
+ packetEPPkt.PktType = tcpip.PacketHost
+ }
+
+ ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone())
}
- if protoEPs != nil {
+
+ n.packetEPs.mu.Lock()
+ // Are any packet type sockets listening for this network protocol?
+ protoEPs, protoEPsOK := n.packetEPs.eps[protocol]
+ // Other packet type sockets that are listening for all protocols.
+ anyEPs, anyEPsOK := n.packetEPs.eps[header.EthernetProtocolAll]
+ n.packetEPs.mu.Unlock()
+
+ if protoEPsOK {
protoEPs.forEach(deliverPacketEPs)
}
- if anyEPs != nil {
+ if anyEPsOK {
anyEPs.forEach(deliverPacketEPs)
}
networkEndpoint.HandlePacket(pkt)
}
-// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
-func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- n.mu.RLock()
+// deliverOutboundPacket delivers outgoing packets to interested endpoints.
+func (n *nic) deliverOutboundPacket(remote tcpip.LinkAddress, pkt *PacketBuffer) {
+ n.packetEPs.mu.RLock()
+ defer n.packetEPs.mu.RUnlock()
// We do not deliver to protocol specific packet endpoints as on Linux
// only ETH_P_ALL endpoints get outbound packets.
// Add any other packet sockets that maybe listening for all protocols.
- eps := n.mu.packetEPs[header.EthernetProtocolAll]
- n.mu.RUnlock()
+ eps, ok := n.packetEPs.eps[header.EthernetProtocolAll]
+ if !ok {
+ return
+ }
+
+ local := n.LinkAddress()
+ var packetEPPkt *PacketBuffer
eps.forEach(func(ep PacketEndpoint) {
- p := pkt.Clone()
- p.PktType = tcpip.PacketOutgoing
- // Add the link layer header as outgoing packets are intercepted
- // before the link layer header is created.
- n.LinkEndpoint.AddHeader(local, remote, protocol, p)
- ep.HandlePacket(n.id, local, protocol, p)
+ if packetEPPkt == nil {
+ // Packet endpoints hold the full packet.
+ //
+ // We perform a deep copy because higher-level endpoints may point to
+ // the middle of a view that is held by a packet endpoint. Save/Restore
+ // does not support overlapping slices and will panic in this case.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports
+ // overlapping slices (e.g. by passing a shallow copy of pkt to the packet
+ // endpoint).
+ packetEPPkt = NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: pkt.AvailableHeaderBytes(),
+ Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
+ })
+ // Add the link layer header as outgoing packets are intercepted before
+ // the link layer header is created and packet endpoints are interested
+ // in the link header.
+ n.LinkEndpoint.AddHeader(local, remote, pkt.NetworkProtocolNumber, packetEPPkt)
+ packetEPPkt.PktType = tcpip.PacketOutgoing
+ }
+
+ ep.HandlePacket(n.id, local, pkt.NetworkProtocolNumber, packetEPPkt.Clone())
})
}
@@ -917,12 +971,13 @@ func (n *nic) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigura
}
func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
+ n.packetEPs.mu.Lock()
+ defer n.packetEPs.mu.Unlock()
- eps, ok := n.mu.packetEPs[netProto]
+ eps, ok := n.packetEPs.eps[netProto]
if !ok {
- return &tcpip.ErrNotSupported{}
+ eps = new(packetEndpointList)
+ n.packetEPs.eps[netProto] = eps
}
eps.add(ep)
@@ -930,14 +985,17 @@ func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
}
func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
- n.mu.Lock()
- defer n.mu.Unlock()
+ n.packetEPs.mu.Lock()
+ defer n.packetEPs.mu.Unlock()
- eps, ok := n.mu.packetEPs[netProto]
+ eps, ok := n.packetEPs.eps[netProto]
if !ok {
return
}
eps.remove(ep)
+ if eps.len() == 0 {
+ delete(n.packetEPs.eps, netProto)
+ }
}
// isValidForOutgoing returns true if the endpoint can be used to send out a