diff options
Diffstat (limited to 'pkg/tcpip')
34 files changed, 177 insertions, 44 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 7a159d5c8..b0c059bf1 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -58,6 +58,8 @@ deps_test( "//pkg/state/wire", "//pkg/sync", "//pkg/waiter", + "//pkg/refsvfs2", + "//pkg/refs", "//pkg/syserr", "//pkg/abi/linux/errno", "//pkg/errors", diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 658557d62..270fa8c79 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -81,6 +81,18 @@ func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) { } func (q *queue) Write(p PacketInfo) bool { + // q holds the PacketBuffer. + + // Ideally, Write() should take a reference here, since it is adding + // the underlying PacketBuffer to the channel. However, in practice, + // calls to Read() are not necessarily symetric with calls + // to Write() (e.g writing to this endpoint and then exiting). This + // causes tests and analyzers to detect erroneous "leaks" for expected + // behavior. To prevent this, we allow the refcount to go to zero, and + // make a call to PreserveObject(), which prevents the PacketBuffer + // pooling implementation from reclaiming this instance, even when + // the refcount goes to zero. + p.Pkt.PreserveObject() wrote := false select { case q.c <- p: diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 3f516cab5..47047578d 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -194,6 +194,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) { pbuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(pkt).ToVectorisedView(), }) + defer pbuf.DecRef() if d.e.hdrSize > 0 { if _, ok := pbuf.LinkHeader().Consume(d.e.hdrSize); !ok { panic(fmt.Sprintf("LinkHeader().Consume(%d) must succeed", d.e.hdrSize)) diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index fab34c5fa..c22bba3b5 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -181,6 +181,7 @@ func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: d.buf.pullViews(n), }) + defer pkt.DecRef() var ( p tcpip.NetworkProtocolNumber @@ -289,6 +290,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: d.bufs[k].pullViews(n), }) + defer pkt.DecRef() // Mark that this iovec has been processed. d.msgHdrs[k].Msg.Iovlen = 0 diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index ca1f9c08d..49b0a29a9 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -104,6 +104,7 @@ func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: data, }) + defer newPkt.DecRef() e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, pkt.NetworkProtocolNumber, newPkt) return nil diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index c67ca98ea..c2a888054 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -59,9 +59,11 @@ func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocol // avoid a deadlock when a packet triggers a response which leads the stack to // try and take a lock it already holds. for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - })) + }) + e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, newPkt) + newPkt.DecRef() } } diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index c15cbf81b..a68b274b2 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -94,9 +94,7 @@ func (q *queueDispatcher) dispatchLoop() { // We pass a protocol of zero here because each packet carries its // NetworkProtocol. q.lower.WritePackets(stack.RouteInfo{}, batch, 0 /* protocol */) - for pkt := batch.Front(); pkt != nil; pkt = pkt.Next() { - batch.Remove(pkt) - } + batch.DecRef() batch.Reset() } } diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go index eb5abb906..09e5b8314 100644 --- a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go +++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go @@ -54,13 +54,13 @@ func (q *packetBufferQueue) setLimit(limit int) { // enqueue adds the given packet to the queue. // // Returns true when the PacketBuffer is successfully added to the queue, in -// which case ownership of the reference is transferred to the queue. And -// returns false if the queue is full, in which case ownership is retained by -// the caller. +// which case the queue acquires a reference to the PacketBuffer, and +// returns false if the queue is full. func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool { q.mu.Lock() r := q.used < q.limit if r { + s.IncRef() q.list.PushBack(s) q.used++ } @@ -70,7 +70,7 @@ func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool { } // dequeue removes and returns the next PacketBuffer from queue, if one exists. -// Ownership is transferred to the caller. +// Caller is responsible for calling DecRef on the PacketBuffer. func (q *packetBufferQueue) dequeue() *stack.PacketBuffer { q.mu.Lock() s := q.list.Front() diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index b75522a51..8797d1bb9 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -413,6 +413,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(b).ToVectorisedView(), }) + defer pkt.DecRef() var src, dst tcpip.LinkAddress var proto tcpip.NetworkProtocolNumber diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server.go b/pkg/tcpip/link/sharedmem/sharedmem_server.go index c39eca33f..00c8a6a3b 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_server.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_server.go @@ -311,6 +311,7 @@ func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) { if e.addr != "" { hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) if !ok { + pkt.DecRef() continue } eth := header.Ethernet(hdr) @@ -323,6 +324,7 @@ func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) { // IP version information is at the first octet, so pulling up 1 byte. h, ok := pkt.Data().PullUp(1) if !ok { + pkt.DecRef() continue } switch header.IPVersion(h) { @@ -331,11 +333,13 @@ func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) { case header.IPv6Version: proto = header.IPv6ProtocolNumber default: + pkt.DecRef() continue } } // Send packet up the stack. d.DeliverNetworkPacket(src, dst, proto, pkt) + pkt.DecRef() } e.mu.Lock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 2afa95af0..965cc994f 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -209,6 +209,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) vv.TrimFront(len(pkt.LinkHeader().View())) pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) + defer pkt.DecRef() switch protocol { case header.IPv4ProtocolNumber: if ok := parse.IPv4(pkt); !ok { diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index fa2131c28..5230ac281 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -76,6 +76,7 @@ func (d *Device) Release(ctx context.Context) { // Decrease refcount if there is an endpoint associated with this file. if d.endpoint != nil { + d.endpoint.Drain() d.endpoint.RemoveNotify(d.notifyHandle) d.endpoint.DecRef(ctx) d.endpoint = nil @@ -231,6 +232,7 @@ func (d *Device) Write(data []byte) (int64, error) { ReserveHeaderBytes: len(ethHdr), Data: buffer.View(data).ToVectorisedView(), }) + defer pkt.DecRef() copy(pkt.LinkHeader().Push(len(ethHdr)), ethHdr) endpoint.InjectLinkAddr(protocol, remote, pkt) return dataLen, nil diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index e08243547..8e6be7e26 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -198,6 +198,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) + defer respPkt.DecRef() packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize)) respPkt.NetworkProtocolNumber = ProtocolNumber packet.SetIPv4OverEthernet() @@ -339,6 +340,7 @@ func (e *endpoint) sendARPRequest(localAddr, targetAddr tcpip.Address, remoteLin pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(e.MaxHeaderLength()), }) + defer pkt.DecRef() h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) pkt.NetworkProtocolNumber = ProtocolNumber h.SetIPv4OverEthernet() diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 5b7e4b361..ff6be8f0d 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -149,6 +149,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s r.proto = proto } + pkt.IncRef() break } if !holeFound { @@ -166,6 +167,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s }) resPkt := r.holes[0].pkt + resPkt.DecRef() for i := 1; i < len(r.holes); i++ { stack.MergeFragment(resPkt, r.holes[i].pkt) } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 59acbad02..33a1b837e 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -239,7 +239,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4Echo: received.echoRequest.Increment() - // DeliverTransportPacket will take ownership of pkt so don't use it beyond + // DeliverTransportPacket may modify pkt so don't use it beyond // this point. Make a deep copy of the data before pkt gets sent as we will // be modifying fields. Both the ICMP header (with its type modified to // EchoReply) and payload are reused in the reply packet. @@ -320,6 +320,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: replyVV, }) + defer replyPkt.DecRef() replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { @@ -667,6 +668,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize, Data: payload, }) + defer icmpPkt.DecRef() icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index 3ce499298..d9cc4574e 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -322,6 +322,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip ReserveHeaderBytes: int(igmp.ep.MaxHeaderLength()), Data: buffer.View(igmpData).ToVectorisedView(), }) + defer pkt.DecRef() addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */) if addressEndpoint == nil { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index d1d509702..e55f5eea6 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -119,6 +119,7 @@ func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), }) + defer pkt.DecRef() pkt.NICID = e.nic.ID() pkt.NetworkProtocolNumber = ProtocolNumber // Use the same control type as an ICMPv4 destination host unreachable error @@ -534,6 +535,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // removed once the fragmentation is done. originalPkt := pkt if _, _, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { + fragPkt.IncRef() // Modify the packet list in place with the new fragments. pkts.InsertAfter(pkt, fragPkt) pkt = fragPkt @@ -751,10 +753,11 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { } // We need to do a deep copy of the IP packet because - // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do + // WriteHeaderIncludedPacket may modify the packet buffer, but we do // not own it. newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength())) newHdr := header.IPv4(newPkt.NetworkHeader().View()) + defer newPkt.DecRef() // As per RFC 791 page 30, Time to Live, // @@ -859,6 +862,7 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum stats.PacketsReceived.Increment() pkt = pkt.CloneToInbound() + defer pkt.DecRef() pkt.RXTransportChecksumValidated = canSkipRXChecksum h, ok := e.protocol.parseAndValidate(pkt) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index adfc8d8da..402c4c8a8 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -525,6 +525,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborAdvertSize, }) + defer pkt.DecRef() pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize)) packet.SetType(header.ICMPv6NeighborAdvert) @@ -675,6 +676,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, Data: pkt.Data().ExtractVV(), }) + defer replyPkt.DecRef() icmp := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber copy(icmp, h) @@ -1213,6 +1215,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize, Data: payload, }) + defer newPkt.DecRef() newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 7d3e1fd53..0c8ff6fb9 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -296,6 +296,7 @@ func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), }) + defer pkt.DecRef() pkt.NICID = e.nic.ID() pkt.NetworkProtocolNumber = ProtocolNumber e.handleControl(&icmpv6DestinationAddressUnreachableSockError{}, pkt) @@ -855,6 +856,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // removed once the fragmentation is done. originalPkt := pb if _, _, err := e.handleFragments(r, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { + fragPkt.IncRef() // Modify the packet list in place with the new fragments. pkts.InsertAfter(pb, fragPkt) pb = fragPkt @@ -1025,6 +1027,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength())) + defer newPkt.DecRef() newHdr := header.IPv6(newPkt.NetworkHeader().View()) // As per RFC 8200 section 3, @@ -1118,6 +1121,7 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum stats.PacketsReceived.Increment() pkt = pkt.CloneToInbound() + defer pkt.DecRef() pkt.RXTransportChecksumValidated = canSkipRXChecksum h, ok := e.protocol.parseAndValidate(pkt) diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index bc1af193c..06a8e1b89 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -270,6 +270,7 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()) + extensionHeaders.Length(), Data: buffer.View(icmp).ToVectorisedView(), }) + defer pkt.DecRef() if err := addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 938427420..bebf72421 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1807,6 +1807,7 @@ func (ndp *ndpState) startSolicitingRouters() { ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), Data: buffer.View(icmpData).ToVectorisedView(), }) + defer pkt.DecRef() sent := ndp.ep.stats.icmp.packetsSent if err := addIPHeader(localAddr, header.IPv6AllRoutersLinkLocalMulticastAddress, pkt, stack.NetworkHeaderParams{ @@ -1924,6 +1925,7 @@ func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteL ReserveHeaderBytes: int(e.MaxHeaderLength()), Data: buffer.View(icmp).ToVectorisedView(), }) + defer pkt.DecRef() if err := addIPHeader(srcAddr, dstAddr, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 5d76adac1..81eed5b11 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -39,6 +39,17 @@ go_template_instance( }, ) +go_template_instance( + name = "packet_buffer_refs", + out = "packet_buffer_refs.go", + package = "stack", + prefix = "packetBuffer", + template = "//pkg/refsvfs2:refs_template", + types = { + "T": "PacketBuffer", + }, +) + go_library( name = "stack", srcs = [ @@ -59,6 +70,7 @@ go_library( "nud.go", "packet_buffer.go", "packet_buffer_list.go", + "packet_buffer_refs.go", "packet_buffer_unsafe.go", "pending_packets.go", "rand.go", @@ -78,6 +90,7 @@ go_library( "//pkg/ilist", "//pkg/log", "//pkg/rand", + "//pkg/refsvfs2", "//pkg/sleep", "//pkg/sync", "//pkg/tcpip", diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index b9b5c35c8..7cfb836ca 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -372,7 +372,7 @@ func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcp } func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { - // WritePacket takes ownership of pkt, calculate numBytes first. + // WritePacket modifies pkt, calculate numBytes first. numBytes := pkt.Size() pkt.EgressRoute = r @@ -754,6 +754,7 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp packetEPPkt = NewPacketBuffer(PacketBufferOptions{ Data: PayloadSince(pkt.LinkHeader()).ToVectorisedView(), }) + defer packetEPPkt.DecRef() // If a link header was populated in the original packet buffer, then // populate it in the packet buffer we provide to packet endpoints as // packet endpoints inspect link headers. @@ -761,7 +762,9 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp packetEPPkt.PktType = tcpip.PacketHost } - ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone()) + clone := packetEPPkt.Clone() + defer clone.DecRef() + ep.HandlePacket(n.id, local, protocol, clone) } n.packetEPs.mu.Lock() @@ -811,14 +814,16 @@ func (n *nic) deliverOutboundPacket(remote tcpip.LinkAddress, pkt *PacketBuffer) ReserveHeaderBytes: pkt.AvailableHeaderBytes(), Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), }) + defer packetEPPkt.DecRef() // Add the link layer header as outgoing packets are intercepted before // the link layer header is created and packet endpoints are interested // in the link header. n.LinkEndpoint.AddHeader(local, remote, pkt.NetworkProtocolNumber, packetEPPkt) packetEPPkt.PktType = tcpip.PacketOutgoing } - - ep.HandlePacket(n.id, local, pkt.NetworkProtocolNumber, packetEPPkt.Clone()) + clone := packetEPPkt.Clone() + defer clone.DecRef() + ep.HandlePacket(n.id, local, pkt.NetworkProtocolNumber, clone) }) } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index c4a4bbd22..2016f7b19 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -88,6 +88,8 @@ type PacketBufferOptions struct { type PacketBuffer struct { _ sync.NoCopy + packetBufferRefs + // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. PacketBufferEntry @@ -149,6 +151,8 @@ type PacketBuffer struct { NetworkPacketInfo NetworkPacketInfo tuple *tuple + + preserveObject bool } // NewPacketBuffer creates a new PacketBuffer with opts. @@ -166,9 +170,21 @@ func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { if opts.IsForwardedPacket { pk.NetworkPacketInfo.IsForwardedPacket = opts.IsForwardedPacket } + pk.InitRefs() return pk } +// DecRef overrides refsvfs2 DecRef and passes a nil destroy function. +func (pk *PacketBuffer) DecRef() { + pk.packetBufferRefs.DecRef(nil) +} + +// PreserveObject marks this PacketBuffer so it is not recycled by internal +// pooling. +func (pk *PacketBuffer) PreserveObject() { + pk.preserveObject = true +} + // ReservedHeaderBytes returns the number of bytes initially reserved for // headers. func (pk *PacketBuffer) ReservedHeaderBytes() int { @@ -291,7 +307,7 @@ func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View { // Clone makes a semi-deep copy of pk. The underlying packet payload is // shared. Hence, no modifications is done to underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { - return &PacketBuffer{ + newPk := &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, buf: pk.buf.Clone(), reserved: pk.reserved, @@ -311,6 +327,8 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NetworkPacketInfo: pk.NetworkPacketInfo, tuple: pk.tuple, } + newPk.InitRefs() + return newPk } // Network returns the network header as a header.Network. @@ -339,6 +357,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { reserved: pk.AvailableHeaderBytes(), tuple: pk.tuple, } + newPk.InitRefs() return newPk } @@ -375,6 +394,22 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu return newPk } +// IncRef increases the reference count on each PacketBuffer +// stored in the PacketBufferList. +func (pk *PacketBufferList) IncRef() { + for pb := pk.Front(); pb != nil; pb = pb.Next() { + pb.IncRef() + } +} + +// DecRef decreases the reference count on each PacketBuffer +// stored in the PacketBufferList. +func (pk *PacketBufferList) DecRef() { + for pb := pk.Front(); pb != nil; pb = pb.Next() { + pb.DecRef() + } +} + // headerInfo stores metadata about a header in a packet. type headerInfo struct { // offset is the offset of the header in pk.buf relative to @@ -460,7 +495,7 @@ func (d PacketData) AppendView(v tcpipbuffer.View) { d.pk.buf.AppendOwned(v) } -// MergeFragment appends the data portion of frag to dst. It takes ownership of +// MergeFragment appends the data portion of frag to dst. It modifies // frag and frag should not be used again. func MergeFragment(dst, frag *PacketBuffer) { frag.buf.TrimFront(int64(frag.dataOffset())) diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 13e8907ec..7e18d4bc4 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -152,6 +152,12 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, proto tcpip.NetworkProt proto: proto, pkt: pkt, }) + switch pkt := pkt.(type) { + case *PacketBuffer: + pkt.IncRef() + case *PacketBufferList: + pkt.IncRef() + } if len(packets) > maxPendingPacketsPerResolution { f.incrementOutgoingPacketErrors(packets[0].proto, packets[0].pkt) @@ -226,5 +232,11 @@ func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, l } } } + switch pkt := p.pkt.(type) { + case *PacketBuffer: + pkt.DecRef() + case *PacketBufferList: + pkt.DecRef() + } } } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 31b3a554d..5db9ad1b1 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -102,12 +102,12 @@ type TransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to this // transport endpoint. It sets the packet buffer's transport header. // - // HandlePacket takes ownership of the packet. + // HandlePacket may modify the packet. HandlePacket(TransportEndpointID, *PacketBuffer) // HandleError is called when the transport endpoint receives an error. // - // HandleError takes ownership of the packet buffer. + // HandleError takes may modify the packet buffer. HandleError(TransportError, *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint @@ -135,7 +135,7 @@ type RawTransportEndpoint interface { // this transport endpoint. The packet contains all data from the link // layer up. // - // HandlePacket takes ownership of the packet. + // HandlePacket may modify the packet. HandlePacket(*PacketBuffer) } @@ -153,7 +153,7 @@ type PacketEndpoint interface { // linkHeader may have a length of 0, in which case the PacketEndpoint // should construct its own ethernet header for applications. // - // HandlePacket takes ownership of pkt. + // HandlePacket may modify pkt. HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } @@ -202,7 +202,7 @@ type TransportProtocol interface { // protocol that don't match any existing endpoint. For example, // it is targeted at a port that has no listeners. // - // HandleUnknownDestinationPacket takes ownership of the packet if it handles + // HandleUnknownDestinationPacket may modify the packet if it handles // the issue. HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition @@ -257,13 +257,13 @@ type TransportDispatcher interface { // // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // - // DeliverTransportPacket takes ownership of the packet. + // DeliverTransportPacket may modify the packet. DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition // DeliverTransportError delivers an error to the appropriate transport // endpoint. // - // DeliverTransportError takes ownership of the packet buffer. + // DeliverTransportError may modify the packet buffer. DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer) // DeliverRawPacket delivers a packet to any subscribed raw sockets. @@ -570,14 +570,14 @@ type NetworkInterface interface { // WritePacket writes a packet with the given protocol through the given // route. // - // WritePacket takes ownership of the packet buffer. The packet buffer's + // WritePacket may modify the packet buffer. The packet buffer's // network and transport header must be set. WritePacket(*Route, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol through the given // route. Must not be called with an empty list of packet buffers. // - // WritePackets takes ownership of the packet buffers. + // WritePackets may modify the packet buffers. // // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters @@ -636,23 +636,23 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. It takes ownership of pkt. pkt.TransportHeader must have + // protocol. It may modify pkt. pkt.TransportHeader must have // already been set. WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error // WritePackets writes packets to the given destination address and - // protocol. pkts must not be zero length. It takes ownership of pkts and + // protocol. pkts must not be zero length. It may modify pkts and // underlying packets. WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network - // header to the given destination address. It takes ownership of pkt. + // header to the given destination address. It may modify pkt. WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. It sets pkt.NetworkHeader. // - // HandlePacket takes ownership of pkt. + // HandlePacket may modify pkt. HandlePacket(pkt *PacketBuffer) // Close is called when the endpoint is removed from a stack. @@ -748,7 +748,7 @@ type NetworkDispatcher interface { // DeliverNetworkPacket. Some packets do not have link headers (e.g. // packets sent via loopback), and won't have the field set. // - // DeliverNetworkPacket takes ownership of pkt. + // DeliverNetworkPacket may modify pkt. DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } @@ -836,7 +836,7 @@ type LinkEndpoint interface { // WritePacket writes a packet with the given protocol and route. // - // WritePacket takes ownership of the packet buffer. The packet buffer's + // WritePacket may modify the packet buffer. The packet buffer's // network and transport header must be set. // // To participate in transparent bridging, a LinkEndpoint implementation @@ -847,7 +847,7 @@ type LinkEndpoint interface { // WritePackets writes packets with the given protocol and route. Must not be // called with an empty list of packet buffers. // - // WritePackets takes ownership of the packet buffers. + // WritePackets may modify the packet buffers. // // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters @@ -859,7 +859,7 @@ type LinkEndpoint interface { // If the link-layer has its own header, the payload must already include the // header. // - // WriteRawPacket takes ownership of the packet. + // WriteRawPacket may modify the packet. WriteRawPacket(*PacketBuffer) tcpip.Error } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a05fd7036..3ddf9de6b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1622,6 +1622,7 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, ReserveHeaderBytes: int(nic.MaxHeaderLength()), Data: payload, }) + defer pkt.DecRef() pkt.NetworkProtocolNumber = netProto return nic.WritePacketToRemote(remote, netProto, pkt) } @@ -1639,6 +1640,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, proto tcpip.NetworkProtocolNum pkt := NewPacketBuffer(PacketBufferOptions{ Data: payload, }) + defer pkt.DecRef() pkt.NetworkProtocolNumber = proto return nic.WriteRawPacket(pkt) } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 3474c292a..088913b83 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -401,14 +401,16 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] - // HandlePacket takes ownership of pkt, so each endpoint needs + // HandlePacket may modify pkt, so each endpoint needs // its own copy except for the final one. for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] { + clone := pkt.Clone() if mustQueue { - queuedProtocol.QueuePacket(endpoint, id, pkt.Clone()) + queuedProtocol.QueuePacket(endpoint, id, clone) } else { - endpoint.HandlePacket(id, pkt.Clone()) + endpoint.HandlePacket(id, clone) } + clone.DecRef() } if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue { queuedProtocol.QueuePacket(endpoint, id, pkt) @@ -559,10 +561,12 @@ func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, d.stack.stats.UDP.UnknownPortErrors.Increment() return false } - // handlePacket takes ownership of pkt, so each endpoint needs its own + // handlePacket takes may modify pkt, so each endpoint needs its own // copy except for the final one. for _, ep := range destEPs[:len(destEPs)-1] { - ep.handlePacket(id, pkt.Clone()) + clone := pkt.Clone() + ep.handlePacket(id, clone) + clone.DecRef() } destEPs[len(destEPs)-1].handlePacket(id, pkt) return true @@ -615,7 +619,9 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb for _, rawEP := range rawEPs { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(pkt.Clone()) + clone := pkt.Clone() + rawEP.HandlePacket(clone) + clone.DecRef() } return len(rawEPs) != 0 diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 654309584..bf6c69e3b 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -370,9 +370,11 @@ func rxICMPv4Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty hea }) ip.SetChecksum(^ip.CalculateChecksum()) - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - })) + }) + defer newPkt.DecRef() + e.InjectInbound(header.IPv4ProtocolNumber, newPkt) } // RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on @@ -408,9 +410,11 @@ func rxICMPv6Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty hea DstAddr: dst, }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - })) + }) + defer newPkt.DecRef() + e.InjectInbound(header.IPv6ProtocolNumber, newPkt) } // RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 995f58616..249588241 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -354,6 +354,7 @@ func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: header.ICMPv4MinimumSize + int(maxHeaderLength), }) + defer pkt.DecRef() icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber @@ -394,6 +395,7 @@ func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: header.ICMPv6MinimumSize + int(maxHeaderLength), }) + defer pkt.DecRef() icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ce76774af..37170cd7d 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -286,6 +286,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp ReserveHeaderBytes: int(ctx.PacketInfo().MaxHeaderLength), Data: buffer.View(payloadBytes).ToVectorisedView(), }) + defer pkt.DecRef() if err := ctx.WritePacket(pkt, e.ops.GetHeaderIncluded()); err != nil { return 0, err diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 12df7a7b4..b1ab46c97 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -850,6 +850,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso pkt.GSOOptions = gso pkts.PushBack(pkt) } + defer pkts.DecRef() if tf.ttl == 0 { tf.ttl = r.DefaultTTL() @@ -878,6 +879,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso stack ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen, Data: data, }) + defer pkt.DecRef() pkt.GSOOptions = gso pkt.Hash = tf.txHash pkt.Owner = owner diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 88bb99354..fd0eca4bd 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -421,6 +421,7 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), }) + defer pkt.DecRef() c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } @@ -477,6 +478,7 @@ func (c *Context) SendSegment(s buffer.VectorisedView) { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: s, }) + defer pkt.DecRef() c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } @@ -486,6 +488,7 @@ func (c *Context) SendPacket(payload []byte, h *Headers) { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: c.BuildSegment(payload, h), }) + defer pkt.DecRef() c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 077a2325a..4f5d0bb0e 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -440,6 +440,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp ReserveHeaderBytes: header.UDPMinimumSize + int(pktInfo.MaxHeaderLength), Data: udpInfo.data.ToVectorisedView(), }) + defer pkt.DecRef() // Initialize the UDP header. udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) |