diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 39 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer_refs.go | 140 | ||||
-rw-r--r-- | pkg/tcpip/stack/pending_packets.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_state_autogen.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 18 |
8 files changed, 254 insertions, 29 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index b9b5c35c8..7cfb836ca 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -372,7 +372,7 @@ func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcp } func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { - // WritePacket takes ownership of pkt, calculate numBytes first. + // WritePacket modifies pkt, calculate numBytes first. numBytes := pkt.Size() pkt.EgressRoute = r @@ -754,6 +754,7 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp packetEPPkt = NewPacketBuffer(PacketBufferOptions{ Data: PayloadSince(pkt.LinkHeader()).ToVectorisedView(), }) + defer packetEPPkt.DecRef() // If a link header was populated in the original packet buffer, then // populate it in the packet buffer we provide to packet endpoints as // packet endpoints inspect link headers. @@ -761,7 +762,9 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp packetEPPkt.PktType = tcpip.PacketHost } - ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone()) + clone := packetEPPkt.Clone() + defer clone.DecRef() + ep.HandlePacket(n.id, local, protocol, clone) } n.packetEPs.mu.Lock() @@ -811,14 +814,16 @@ func (n *nic) deliverOutboundPacket(remote tcpip.LinkAddress, pkt *PacketBuffer) ReserveHeaderBytes: pkt.AvailableHeaderBytes(), Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), }) + defer packetEPPkt.DecRef() // Add the link layer header as outgoing packets are intercepted before // the link layer header is created and packet endpoints are interested // in the link header. 
n.LinkEndpoint.AddHeader(local, remote, pkt.NetworkProtocolNumber, packetEPPkt) packetEPPkt.PktType = tcpip.PacketOutgoing } - - ep.HandlePacket(n.id, local, pkt.NetworkProtocolNumber, packetEPPkt.Clone()) + clone := packetEPPkt.Clone() + defer clone.DecRef() + ep.HandlePacket(n.id, local, pkt.NetworkProtocolNumber, clone) }) } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index c4a4bbd22..2016f7b19 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -88,6 +88,8 @@ type PacketBufferOptions struct { type PacketBuffer struct { _ sync.NoCopy + packetBufferRefs + // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. PacketBufferEntry @@ -149,6 +151,8 @@ type PacketBuffer struct { NetworkPacketInfo NetworkPacketInfo tuple *tuple + + preserveObject bool } // NewPacketBuffer creates a new PacketBuffer with opts. @@ -166,9 +170,21 @@ func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { if opts.IsForwardedPacket { pk.NetworkPacketInfo.IsForwardedPacket = opts.IsForwardedPacket } + pk.InitRefs() return pk } +// DecRef overrides refsvfs2 DecRef and passes a nil destroy function. +func (pk *PacketBuffer) DecRef() { + pk.packetBufferRefs.DecRef(nil) +} + +// PreserveObject marks this PacketBuffer so it is not recycled by internal +// pooling. +func (pk *PacketBuffer) PreserveObject() { + pk.preserveObject = true +} + // ReservedHeaderBytes returns the number of bytes initially reserved for // headers. func (pk *PacketBuffer) ReservedHeaderBytes() int { @@ -291,7 +307,7 @@ func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View { // Clone makes a semi-deep copy of pk. The underlying packet payload is // shared. Hence, no modifications is done to underlying packet payload. 
func (pk *PacketBuffer) Clone() *PacketBuffer { - return &PacketBuffer{ + newPk := &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, buf: pk.buf.Clone(), reserved: pk.reserved, @@ -311,6 +327,8 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NetworkPacketInfo: pk.NetworkPacketInfo, tuple: pk.tuple, } + newPk.InitRefs() + return newPk } // Network returns the network header as a header.Network. @@ -339,6 +357,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { reserved: pk.AvailableHeaderBytes(), tuple: pk.tuple, } + newPk.InitRefs() return newPk } @@ -375,6 +394,22 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu return newPk } +// IncRef increases the reference count on each PacketBuffer +// stored in the PacketBufferList. +func (pk *PacketBufferList) IncRef() { + for pb := pk.Front(); pb != nil; pb = pb.Next() { + pb.IncRef() + } +} + +// DecRef decreases the reference count on each PacketBuffer +// stored in the PacketBufferList. +func (pk *PacketBufferList) DecRef() { + for pb := pk.Front(); pb != nil; pb = pb.Next() { + pb.DecRef() + } +} + // headerInfo stores metadata about a header in a packet. type headerInfo struct { // offset is the offset of the header in pk.buf relative to @@ -460,7 +495,7 @@ func (d PacketData) AppendView(v tcpipbuffer.View) { d.pk.buf.AppendOwned(v) } -// MergeFragment appends the data portion of frag to dst. It takes ownership of +// MergeFragment appends the data portion of frag to dst. It modifies // frag and frag should not be used again. 
func MergeFragment(dst, frag *PacketBuffer) { frag.buf.TrimFront(int64(frag.dataOffset())) diff --git a/pkg/tcpip/stack/packet_buffer_refs.go b/pkg/tcpip/stack/packet_buffer_refs.go new file mode 100644 index 000000000..c756afd64 --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_refs.go @@ -0,0 +1,140 @@ +package stack + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const packetBufferenableLogging = false + +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var packetBufferobj *PacketBuffer + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// +// +stateify savable +type packetBufferRefs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. 
+ refCount int64 +} + +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *packetBufferRefs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *packetBufferRefs) RefType() string { + return fmt.Sprintf("%T", packetBufferobj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *packetBufferRefs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *packetBufferRefs) LogRefs() bool { + return packetBufferenableLogging +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *packetBufferRefs) ReadRefs() int64 { + return atomic.LoadInt64(&r.refCount) +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *packetBufferRefs) IncRef() { + v := atomic.AddInt64(&r.refCount, 1) + if packetBufferenableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) + } +} + +// TryIncRef implements refs.TryRefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *packetBufferRefs) TryIncRef() bool { + const speculativeRef = 1 << 32 + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if packetBufferenableLogging { + refsvfs2.LogTryIncRef(r, v) + } + return true +} + +// DecRef implements refs.RefCounter.DecRef. 
+// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *packetBufferRefs) DecRef(destroy func()) { + v := atomic.AddInt64(&r.refCount, -1) + if packetBufferenableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) + + case v == 0: + refsvfs2.Unregister(r) + + if destroy != nil { + destroy() + } + } +} + +func (r *packetBufferRefs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 13e8907ec..7e18d4bc4 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -152,6 +152,12 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, proto tcpip.NetworkProt proto: proto, pkt: pkt, }) + switch pkt := pkt.(type) { + case *PacketBuffer: + pkt.IncRef() + case *PacketBufferList: + pkt.IncRef() + } if len(packets) > maxPendingPacketsPerResolution { f.incrementOutgoingPacketErrors(packets[0].proto, packets[0].pkt) @@ -226,5 +232,11 @@ func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, l } } } + switch pkt := p.pkt.(type) { + case *PacketBuffer: + pkt.DecRef() + case *PacketBufferList: + pkt.DecRef() + } } } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 31b3a554d..5db9ad1b1 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -102,12 +102,12 @@ type TransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to this // 
transport endpoint. It sets the packet buffer's transport header. // - // HandlePacket takes ownership of the packet. + // HandlePacket may modify the packet. HandlePacket(TransportEndpointID, *PacketBuffer) // HandleError is called when the transport endpoint receives an error. // - // HandleError takes ownership of the packet buffer. + // HandleError may modify the packet buffer. HandleError(TransportError, *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint @@ -135,7 +135,7 @@ type RawTransportEndpoint interface { // this transport endpoint. The packet contains all data from the link // layer up. // - // HandlePacket takes ownership of the packet. + // HandlePacket may modify the packet. HandlePacket(*PacketBuffer) } @@ -153,7 +153,7 @@ type PacketEndpoint interface { // linkHeader may have a length of 0, in which case the PacketEndpoint // should construct its own ethernet header for applications. // - // HandlePacket takes ownership of pkt. + // HandlePacket may modify pkt. HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } @@ -202,7 +202,7 @@ type TransportProtocol interface { // protocol that don't match any existing endpoint. For example, // it is targeted at a port that has no listeners. // - // HandleUnknownDestinationPacket takes ownership of the packet if it handles + // HandleUnknownDestinationPacket may modify the packet if it handles // the issue. HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition @@ -257,13 +257,13 @@ type TransportDispatcher interface { // // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // - // DeliverTransportPacket takes ownership of the packet. + // DeliverTransportPacket may modify the packet. 
DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition // DeliverTransportError delivers an error to the appropriate transport // endpoint. // - // DeliverTransportError takes ownership of the packet buffer. + // DeliverTransportError may modify the packet buffer. DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer) // DeliverRawPacket delivers a packet to any subscribed raw sockets. @@ -570,14 +570,14 @@ type NetworkInterface interface { // WritePacket writes a packet with the given protocol through the given // route. // - // WritePacket takes ownership of the packet buffer. The packet buffer's + // WritePacket may modify the packet buffer. The packet buffer's // network and transport header must be set. WritePacket(*Route, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol through the given // route. Must not be called with an empty list of packet buffers. // - // WritePackets takes ownership of the packet buffers. + // WritePackets may modify the packet buffers. // // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters @@ -636,23 +636,23 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. It takes ownership of pkt. pkt.TransportHeader must have + // protocol. It may modify pkt. pkt.TransportHeader must have // already been set. WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error // WritePackets writes packets to the given destination address and - // protocol. pkts must not be zero length. It takes ownership of pkts and + // protocol. pkts must not be zero length. It may modify pkts and // underlying packets. 
WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network - // header to the given destination address. It takes ownership of pkt. + // header to the given destination address. It may modify pkt. WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. It sets pkt.NetworkHeader. // - // HandlePacket takes ownership of pkt. + // HandlePacket may modify pkt. HandlePacket(pkt *PacketBuffer) // Close is called when the endpoint is removed from a stack. @@ -748,7 +748,7 @@ type NetworkDispatcher interface { // DeliverNetworkPacket. Some packets do not have link headers (e.g. // packets sent via loopback), and won't have the field set. // - // DeliverNetworkPacket takes ownership of pkt. + // DeliverNetworkPacket may modify pkt. DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } @@ -836,7 +836,7 @@ type LinkEndpoint interface { // WritePacket writes a packet with the given protocol and route. // - // WritePacket takes ownership of the packet buffer. The packet buffer's + // WritePacket may modify the packet buffer. The packet buffer's // network and transport header must be set. // // To participate in transparent bridging, a LinkEndpoint implementation @@ -847,7 +847,7 @@ type LinkEndpoint interface { // WritePackets writes packets with the given protocol and route. Must not be // called with an empty list of packet buffers. // - // WritePackets takes ownership of the packet buffers. + // WritePackets may modify the packet buffers. // // Right now, WritePackets is used only when the software segmentation // offload is enabled. 
If it will be used for something else, syscall filters @@ -859,7 +859,7 @@ type LinkEndpoint interface { // If the link-layer has its own header, the payload must already include the // header. // - // WriteRawPacket takes ownership of the packet. + // WriteRawPacket may modify the packet. WriteRawPacket(*PacketBuffer) tcpip.Error } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a05fd7036..3ddf9de6b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1622,6 +1622,7 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, ReserveHeaderBytes: int(nic.MaxHeaderLength()), Data: payload, }) + defer pkt.DecRef() pkt.NetworkProtocolNumber = netProto return nic.WritePacketToRemote(remote, netProto, pkt) } @@ -1639,6 +1640,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, proto tcpip.NetworkProtocolNum pkt := NewPacketBuffer(PacketBufferOptions{ Data: payload, }) + defer pkt.DecRef() pkt.NetworkProtocolNumber = proto return nic.WriteRawPacket(pkt) } diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go index e4752bab1..b6ecac560 100644 --- a/pkg/tcpip/stack/stack_state_autogen.go +++ b/pkg/tcpip/stack/stack_state_autogen.go @@ -460,6 +460,30 @@ func (e *PacketBufferEntry) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(1, &e.prev) } +func (r *packetBufferRefs) StateTypeName() string { + return "pkg/tcpip/stack.packetBufferRefs" +} + +func (r *packetBufferRefs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (r *packetBufferRefs) beforeSave() {} + +// +checklocksignore +func (r *packetBufferRefs) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.refCount) +} + +// +checklocksignore +func (r *packetBufferRefs) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.refCount) + stateSourceObject.AfterLoad(r.afterLoad) +} + func (t *TransportEndpointID) 
StateTypeName() string { return "pkg/tcpip/stack.TransportEndpointID" } @@ -1251,6 +1275,7 @@ func init() { state.Register((*neighborEntryEntry)(nil)) state.Register((*PacketBufferList)(nil)) state.Register((*PacketBufferEntry)(nil)) + state.Register((*packetBufferRefs)(nil)) state.Register((*TransportEndpointID)(nil)) state.Register((*GSOType)(nil)) state.Register((*GSO)(nil)) diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 3474c292a..088913b83 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -401,14 +401,16 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] - // HandlePacket takes ownership of pkt, so each endpoint needs + // HandlePacket may modify pkt, so each endpoint needs // its own copy except for the final one. for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] { + clone := pkt.Clone() if mustQueue { - queuedProtocol.QueuePacket(endpoint, id, pkt.Clone()) + queuedProtocol.QueuePacket(endpoint, id, clone) } else { - endpoint.HandlePacket(id, pkt.Clone()) + endpoint.HandlePacket(id, clone) } + clone.DecRef() } if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue { queuedProtocol.QueuePacket(endpoint, id, pkt) @@ -559,10 +561,12 @@ func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, d.stack.stats.UDP.UnknownPortErrors.Increment() return false } - // handlePacket takes ownership of pkt, so each endpoint needs its own + // handlePacket may modify pkt, so each endpoint needs its own // 
for _, ep := range destEPs[:len(destEPs)-1] { - ep.handlePacket(id, pkt.Clone()) + clone := pkt.Clone() + ep.handlePacket(id, clone) + clone.DecRef() } destEPs[len(destEPs)-1].handlePacket(id, pkt) return true @@ -615,7 +619,9 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb for _, rawEP := range rawEPs { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(pkt.Clone()) + clone := pkt.Clone() + rawEP.HandlePacket(clone) + clone.DecRef() } return len(rawEPs) != 0 |