diff options
Diffstat (limited to 'pkg/tcpip/stack/pending_packets.go')
-rw-r--r-- | pkg/tcpip/stack/pending_packets.go | 51 |
1 files changed, 42 insertions, 9 deletions
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 81d8ff6e8..3ac039c7d 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -28,10 +28,26 @@ const ( maxPendingPacketsPerResolution = 256 ) +// pendingPacketBuffer is a pending packet buffer. +// +// TODO(gvisor.dev/issue/5331): Drop this when we drop WritePacket and only use +// WritePackets so we can use a PacketBufferList everywhere. +type pendingPacketBuffer interface { + len() int +} + +func (*PacketBuffer) len() int { + return 1 +} + +func (p *PacketBufferList) len() int { + return p.Len() +} + type pendingPacket struct { route *Route proto tcpip.NetworkProtocolNumber - pkt *PacketBuffer + pkt pendingPacketBuffer } // packetsPendingLinkResolution is a queue of packets pending link resolution. @@ -54,16 +70,17 @@ func (f *packetsPendingLinkResolution) init() { f.packets = make(map[<-chan struct{}][]pendingPacket) } -func incrementOutgoingPacketErrors(r *Route, proto tcpip.NetworkProtocolNumber) { - r.Stats().IP.OutgoingPacketErrors.Increment() +func incrementOutgoingPacketErrors(r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) { + n := uint64(pkt.len()) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(n) // ok may be false if the endpoint's stats do not collect IP-related data. if ipEndpointStats, ok := r.outgoingNIC.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok { - ipEndpointStats.IPStats().OutgoingPacketErrors.Increment() + ipEndpointStats.IPStats().OutgoingPacketErrors.IncrementBy(n) } } -func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) { f.Lock() defer f.Unlock() @@ -73,7 +90,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro packets[0] = pendingPacket{} packets = packets[1:] - incrementOutgoingPacketErrors(r, proto) + incrementOutgoingPacketErrors(r, proto, p.pkt) p.route.Release() } @@ -113,13 +130,29 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro for _, p := range packets { if cancelled || p.route.IsResolutionRequired() { - incrementOutgoingPacketErrors(r, proto) + incrementOutgoingPacketErrors(r, proto, p.pkt) if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok { - linkResolvableEP.HandleLinkResolutionFailure(pkt) + switch pkt := p.pkt.(type) { + case *PacketBuffer: + linkResolvableEP.HandleLinkResolutionFailure(pkt) + case *PacketBufferList: + for pb := pkt.Front(); pb != nil; pb = pb.Next() { + linkResolvableEP.HandleLinkResolutionFailure(pb) + } + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt)) + } } } else { - p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, p.pkt) + switch pkt := p.pkt.(type) { + case *PacketBuffer: + p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, pkt) + case *PacketBufferList: + p.route.outgoingNIC.writePackets(p.route.Fields(), nil /* gso */, p.proto, *pkt) + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt)) + } } p.route.Release() } |