summary | refs | log | tree | commit | diff | homepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
author: Lucas Manning <lucasmanning@google.com> 2021-11-08 13:26:02 -0800
committer: gVisor bot <gvisor-bot@google.com> 2021-11-08 13:28:38 -0800
commit: 84b38f4c6e065d3f9314a8abbb3f5857ed4fa44e (patch)
tree: 53eb76fa6d0612696f93ec6919185ea5a37ff3f9 /pkg/tcpip/stack
parent: 49d23beb283d0306c9ccf5300e73517153ddd3c2 (diff)
Add reference counting to packet buffers.
PiperOrigin-RevId: 408426639
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD13
-rw-r--r--pkg/tcpip/stack/nic.go13
-rw-r--r--pkg/tcpip/stack/packet_buffer.go39
-rw-r--r--pkg/tcpip/stack/pending_packets.go12
-rw-r--r--pkg/tcpip/stack/registration.go34
-rw-r--r--pkg/tcpip/stack/stack.go2
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go18
7 files changed, 102 insertions, 29 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 5d76adac1..81eed5b11 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -39,6 +39,17 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "packet_buffer_refs",
+ out = "packet_buffer_refs.go",
+ package = "stack",
+ prefix = "packetBuffer",
+ template = "//pkg/refsvfs2:refs_template",
+ types = {
+ "T": "PacketBuffer",
+ },
+)
+
go_library(
name = "stack",
srcs = [
@@ -59,6 +70,7 @@ go_library(
"nud.go",
"packet_buffer.go",
"packet_buffer_list.go",
+ "packet_buffer_refs.go",
"packet_buffer_unsafe.go",
"pending_packets.go",
"rand.go",
@@ -78,6 +90,7 @@ go_library(
"//pkg/ilist",
"//pkg/log",
"//pkg/rand",
+ "//pkg/refsvfs2",
"//pkg/sleep",
"//pkg/sync",
"//pkg/tcpip",
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index b9b5c35c8..7cfb836ca 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -372,7 +372,7 @@ func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcp
}
func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
- // WritePacket takes ownership of pkt, calculate numBytes first.
+ // WritePacket modifies pkt, calculate numBytes first.
numBytes := pkt.Size()
pkt.EgressRoute = r
@@ -754,6 +754,7 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
packetEPPkt = NewPacketBuffer(PacketBufferOptions{
Data: PayloadSince(pkt.LinkHeader()).ToVectorisedView(),
})
+ defer packetEPPkt.DecRef()
// If a link header was populated in the original packet buffer, then
// populate it in the packet buffer we provide to packet endpoints as
// packet endpoints inspect link headers.
@@ -761,7 +762,9 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
packetEPPkt.PktType = tcpip.PacketHost
}
- ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone())
+ clone := packetEPPkt.Clone()
+ defer clone.DecRef()
+ ep.HandlePacket(n.id, local, protocol, clone)
}
n.packetEPs.mu.Lock()
@@ -811,14 +814,16 @@ func (n *nic) deliverOutboundPacket(remote tcpip.LinkAddress, pkt *PacketBuffer)
ReserveHeaderBytes: pkt.AvailableHeaderBytes(),
Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
})
+ defer packetEPPkt.DecRef()
// Add the link layer header as outgoing packets are intercepted before
// the link layer header is created and packet endpoints are interested
// in the link header.
n.LinkEndpoint.AddHeader(local, remote, pkt.NetworkProtocolNumber, packetEPPkt)
packetEPPkt.PktType = tcpip.PacketOutgoing
}
-
- ep.HandlePacket(n.id, local, pkt.NetworkProtocolNumber, packetEPPkt.Clone())
+ clone := packetEPPkt.Clone()
+ defer clone.DecRef()
+ ep.HandlePacket(n.id, local, pkt.NetworkProtocolNumber, clone)
})
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index c4a4bbd22..2016f7b19 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -88,6 +88,8 @@ type PacketBufferOptions struct {
type PacketBuffer struct {
_ sync.NoCopy
+ packetBufferRefs
+
// PacketBufferEntry is used to build an intrusive list of
// PacketBuffers.
PacketBufferEntry
@@ -149,6 +151,8 @@ type PacketBuffer struct {
NetworkPacketInfo NetworkPacketInfo
tuple *tuple
+
+ preserveObject bool
}
// NewPacketBuffer creates a new PacketBuffer with opts.
@@ -166,9 +170,21 @@ func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer {
if opts.IsForwardedPacket {
pk.NetworkPacketInfo.IsForwardedPacket = opts.IsForwardedPacket
}
+ pk.InitRefs()
return pk
}
+// DecRef overrides refsvfs2 DecRef and passes a nil destroy function.
+func (pk *PacketBuffer) DecRef() {
+ pk.packetBufferRefs.DecRef(nil)
+}
+
+// PreserveObject marks this PacketBuffer so it is not recycled by internal
+// pooling.
+func (pk *PacketBuffer) PreserveObject() {
+ pk.preserveObject = true
+}
+
// ReservedHeaderBytes returns the number of bytes initially reserved for
// headers.
func (pk *PacketBuffer) ReservedHeaderBytes() int {
@@ -291,7 +307,7 @@ func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View {
// Clone makes a semi-deep copy of pk. The underlying packet payload is
// shared. Hence, no modifications is done to underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
- return &PacketBuffer{
+ newPk := &PacketBuffer{
PacketBufferEntry: pk.PacketBufferEntry,
buf: pk.buf.Clone(),
reserved: pk.reserved,
@@ -311,6 +327,8 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
NetworkPacketInfo: pk.NetworkPacketInfo,
tuple: pk.tuple,
}
+ newPk.InitRefs()
+ return newPk
}
// Network returns the network header as a header.Network.
@@ -339,6 +357,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
reserved: pk.AvailableHeaderBytes(),
tuple: pk.tuple,
}
+ newPk.InitRefs()
return newPk
}
@@ -375,6 +394,22 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu
return newPk
}
+// IncRef increases the reference count on each PacketBuffer
+// stored in the PacketBufferList.
+func (pk *PacketBufferList) IncRef() {
+ for pb := pk.Front(); pb != nil; pb = pb.Next() {
+ pb.IncRef()
+ }
+}
+
+// DecRef decreases the reference count on each PacketBuffer
+// stored in the PacketBufferList.
+func (pk *PacketBufferList) DecRef() {
+ for pb := pk.Front(); pb != nil; pb = pb.Next() {
+ pb.DecRef()
+ }
+}
+
// headerInfo stores metadata about a header in a packet.
type headerInfo struct {
// offset is the offset of the header in pk.buf relative to
@@ -460,7 +495,7 @@ func (d PacketData) AppendView(v tcpipbuffer.View) {
d.pk.buf.AppendOwned(v)
}
-// MergeFragment appends the data portion of frag to dst. It takes ownership of
+// MergeFragment appends the data portion of frag to dst. It modifies
// frag and frag should not be used again.
func MergeFragment(dst, frag *PacketBuffer) {
frag.buf.TrimFront(int64(frag.dataOffset()))
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index 13e8907ec..7e18d4bc4 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -152,6 +152,12 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, proto tcpip.NetworkProt
proto: proto,
pkt: pkt,
})
+ switch pkt := pkt.(type) {
+ case *PacketBuffer:
+ pkt.IncRef()
+ case *PacketBufferList:
+ pkt.IncRef()
+ }
if len(packets) > maxPendingPacketsPerResolution {
f.incrementOutgoingPacketErrors(packets[0].proto, packets[0].pkt)
@@ -226,5 +232,11 @@ func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, l
}
}
}
+ switch pkt := p.pkt.(type) {
+ case *PacketBuffer:
+ pkt.DecRef()
+ case *PacketBufferList:
+ pkt.DecRef()
+ }
}
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 31b3a554d..5db9ad1b1 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -102,12 +102,12 @@ type TransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to this
// transport endpoint. It sets the packet buffer's transport header.
//
- // HandlePacket takes ownership of the packet.
+ // HandlePacket may modify the packet.
HandlePacket(TransportEndpointID, *PacketBuffer)
// HandleError is called when the transport endpoint receives an error.
//
- // HandleError takes ownership of the packet buffer.
+ // HandleError may modify the packet buffer.
HandleError(TransportError, *PacketBuffer)
// Abort initiates an expedited endpoint teardown. It puts the endpoint
@@ -135,7 +135,7 @@ type RawTransportEndpoint interface {
// this transport endpoint. The packet contains all data from the link
// layer up.
//
- // HandlePacket takes ownership of the packet.
+ // HandlePacket may modify the packet.
HandlePacket(*PacketBuffer)
}
@@ -153,7 +153,7 @@ type PacketEndpoint interface {
// linkHeader may have a length of 0, in which case the PacketEndpoint
// should construct its own ethernet header for applications.
//
- // HandlePacket takes ownership of pkt.
+ // HandlePacket may modify pkt.
HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
@@ -202,7 +202,7 @@ type TransportProtocol interface {
// protocol that don't match any existing endpoint. For example,
// it is targeted at a port that has no listeners.
//
- // HandleUnknownDestinationPacket takes ownership of the packet if it handles
+ // HandleUnknownDestinationPacket may modify the packet if it handles
// the issue.
HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition
@@ -257,13 +257,13 @@ type TransportDispatcher interface {
//
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
- // DeliverTransportPacket takes ownership of the packet.
+ // DeliverTransportPacket may modify the packet.
DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition
// DeliverTransportError delivers an error to the appropriate transport
// endpoint.
//
- // DeliverTransportError takes ownership of the packet buffer.
+ // DeliverTransportError may modify the packet buffer.
DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer)
// DeliverRawPacket delivers a packet to any subscribed raw sockets.
@@ -570,14 +570,14 @@ type NetworkInterface interface {
// WritePacket writes a packet with the given protocol through the given
// route.
//
- // WritePacket takes ownership of the packet buffer. The packet buffer's
+ // WritePacket may modify the packet buffer. The packet buffer's
// network and transport header must be set.
WritePacket(*Route, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
// WritePackets writes packets with the given protocol through the given
// route. Must not be called with an empty list of packet buffers.
//
- // WritePackets takes ownership of the packet buffers.
+ // WritePackets may modify the packet buffers.
//
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, syscall filters
@@ -636,23 +636,23 @@ type NetworkEndpoint interface {
MaxHeaderLength() uint16
// WritePacket writes a packet to the given destination address and
- // protocol. It takes ownership of pkt. pkt.TransportHeader must have
+ // protocol. It may modify pkt. pkt.TransportHeader must have
// already been set.
WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error
// WritePackets writes packets to the given destination address and
- // protocol. pkts must not be zero length. It takes ownership of pkts and
+ // protocol. pkts must not be zero length. It may modify pkts and
// underlying packets.
WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
- // header to the given destination address. It takes ownership of pkt.
+ // header to the given destination address. It may modify pkt.
WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint. It sets pkt.NetworkHeader.
//
- // HandlePacket takes ownership of pkt.
+ // HandlePacket may modify pkt.
HandlePacket(pkt *PacketBuffer)
// Close is called when the endpoint is removed from a stack.
@@ -748,7 +748,7 @@ type NetworkDispatcher interface {
// DeliverNetworkPacket. Some packets do not have link headers (e.g.
// packets sent via loopback), and won't have the field set.
//
- // DeliverNetworkPacket takes ownership of pkt.
+ // DeliverNetworkPacket may modify pkt.
DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
@@ -836,7 +836,7 @@ type LinkEndpoint interface {
// WritePacket writes a packet with the given protocol and route.
//
- // WritePacket takes ownership of the packet buffer. The packet buffer's
+ // WritePacket may modify the packet buffer. The packet buffer's
// network and transport header must be set.
//
// To participate in transparent bridging, a LinkEndpoint implementation
@@ -847,7 +847,7 @@ type LinkEndpoint interface {
// WritePackets writes packets with the given protocol and route. Must not be
// called with an empty list of packet buffers.
//
- // WritePackets takes ownership of the packet buffers.
+ // WritePackets may modify the packet buffers.
//
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, syscall filters
@@ -859,7 +859,7 @@ type LinkEndpoint interface {
// If the link-layer has its own header, the payload must already include the
// header.
//
- // WriteRawPacket takes ownership of the packet.
+ // WriteRawPacket may modify the packet.
WriteRawPacket(*PacketBuffer) tcpip.Error
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index a05fd7036..3ddf9de6b 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1622,6 +1622,7 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress,
ReserveHeaderBytes: int(nic.MaxHeaderLength()),
Data: payload,
})
+ defer pkt.DecRef()
pkt.NetworkProtocolNumber = netProto
return nic.WritePacketToRemote(remote, netProto, pkt)
}
@@ -1639,6 +1640,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, proto tcpip.NetworkProtocolNum
pkt := NewPacketBuffer(PacketBufferOptions{
Data: payload,
})
+ defer pkt.DecRef()
pkt.NetworkProtocolNumber = proto
return nic.WriteRawPacket(pkt)
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 3474c292a..088913b83 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -401,14 +401,16 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32)
func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
- // HandlePacket takes ownership of pkt, so each endpoint needs
+ // HandlePacket may modify pkt, so each endpoint needs
// its own copy except for the final one.
for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
+ clone := pkt.Clone()
if mustQueue {
- queuedProtocol.QueuePacket(endpoint, id, pkt.Clone())
+ queuedProtocol.QueuePacket(endpoint, id, clone)
} else {
- endpoint.HandlePacket(id, pkt.Clone())
+ endpoint.HandlePacket(id, clone)
}
+ clone.DecRef()
}
if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
queuedProtocol.QueuePacket(endpoint, id, pkt)
@@ -559,10 +561,12 @@ func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber,
d.stack.stats.UDP.UnknownPortErrors.Increment()
return false
}
- // handlePacket takes ownership of pkt, so each endpoint needs its own
+ // handlePacket may modify pkt, so each endpoint needs its own
// copy except for the final one.
for _, ep := range destEPs[:len(destEPs)-1] {
- ep.handlePacket(id, pkt.Clone())
+ clone := pkt.Clone()
+ ep.handlePacket(id, clone)
+ clone.DecRef()
}
destEPs[len(destEPs)-1].handlePacket(id, pkt)
return true
@@ -615,7 +619,9 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb
for _, rawEP := range rawEPs {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
- rawEP.HandlePacket(pkt.Clone())
+ clone := pkt.Clone()
+ rawEP.HandlePacket(clone)
+ clone.DecRef()
}
return len(rawEPs) != 0