summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
authorTing-Yu Wang <anivia@google.com>2021-03-03 16:03:04 -0800
committergVisor bot <gvisor-bot@google.com>2021-03-03 16:05:16 -0800
commit1cd76d958a9b3eb29f6b55a8bea71fbe464e67d3 (patch)
tree1f4df3b516c62a2aa630ffaf9c6ecba99482e3d3 /pkg/tcpip/stack
parentcfd2c31962a4358d7d05a4bd04dde271dc238339 (diff)
Make dedicated methods for data operations in PacketBuffer
One of the preparation to decouple underlying buffer implementation. There are still some methods that tie to VectorisedView, and they will be changed gradually in later CLs. This CL also introduce a new ICMPv6ChecksumParams to replace long list of parameters when calling ICMPv6Checksum, aiming to be more descriptive. PiperOrigin-RevId: 360778149
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/conntrack.go4
-rw-r--r--pkg/tcpip/stack/iptables_targets.go2
-rw-r--r--pkg/tcpip/stack/ndp_test.go18
-rw-r--r--pkg/tcpip/stack/nic.go6
-rw-r--r--pkg/tcpip/stack/packet_buffer.go248
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go224
-rw-r--r--pkg/tcpip/stack/stack_test.go6
7 files changed, 482 insertions, 26 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index cdb435644..3f083928f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -407,12 +407,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(len(tcpHeader) + pkt.Data.Size())
+ length := uint16(len(tcpHeader) + pkt.Data().Size())
xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
} else if r.RequiresTXTransportChecksum() {
- xsum = header.ChecksumVV(pkt.Data, xsum)
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index d63e9757c..0e8b90c9b 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -153,7 +153,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
if r.RequiresTXTransportChecksum() {
length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- xsum = header.ChecksumVV(pkt.Data, xsum)
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 740bdac28..78a4cb072 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -573,7 +573,11 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(tgt)
snmc := header.SolicitedNodeAddr(tgt)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: header.IPv6Any,
+ Dst: snmc,
+ }))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -619,7 +623,11 @@ func TestDADFail(t *testing.T) {
na.Options().Serialize(header.NDPOptionsSerializer{
header.NDPTargetLinkLayerAddressOption(linkAddr1),
})
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: tgt,
+ Dst: header.IPv6AllNodesMulticastAddress,
+ }))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -973,7 +981,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
}
opts := ra.Options()
opts.Serialize(optSer)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: ip,
+ Dst: header.IPv6AllNodesMulticastAddress,
+ }))
payloadLength := hdr.UsedLength()
iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
iph.Encode(&header.IPv6Fields{
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index f9323d545..62f7c880e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -725,12 +725,12 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n.mu.RUnlock()
n.stats.DisabledRx.Packets.Increment()
- n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+ n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size()))
return
}
n.stats.Rx.Packets.Increment()
- n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+ n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size()))
networkEndpoint, ok := n.networkEndpoints[protocol]
if !ok {
@@ -881,7 +881,7 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo
// ICMPv4 only guarantees that 8 bytes of the transport protocol will
// be present in the payload. We know that the ports are within the
// first 8 bytes for all known transport protocols.
- transHeader, ok := pkt.Data.PullUp(8)
+ transHeader, ok := pkt.Data().PullUp(8)
if !ok {
return
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 4f013b212..8f288675d 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -59,7 +59,7 @@ type PacketBuffer struct {
// PacketBuffers.
PacketBufferEntry
- // Data holds the payload of the packet.
+ // data holds the payload of the packet.
//
// For inbound packets, Data is initially the whole packet. Then gets moved to
// headers via PacketHeader.Consume, when the packet is being parsed.
@@ -69,7 +69,7 @@ type PacketBuffer struct {
//
// The bytes backing Data are immutable, a.k.a. users shouldn't write to its
// backing storage.
- Data buffer.VectorisedView
+ data buffer.VectorisedView
// headers stores metadata about each header.
headers [numHeaderType]headerInfo
@@ -127,7 +127,7 @@ type PacketBuffer struct {
// NewPacketBuffer creates a new PacketBuffer with opts.
func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer {
pk := &PacketBuffer{
- Data: opts.Data,
+ data: opts.Data,
}
if opts.ReserveHeaderBytes != 0 {
pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes)
@@ -184,13 +184,18 @@ func (pk *PacketBuffer) HeaderSize() int {
// Size returns the size of packet in bytes.
func (pk *PacketBuffer) Size() int {
- return pk.HeaderSize() + pk.Data.Size()
+ return pk.HeaderSize() + pk.data.Size()
}
// MemSize returns the estimation size of the pk in memory, including backing
// buffer data.
func (pk *PacketBuffer) MemSize() int {
- return pk.HeaderSize() + pk.Data.MemSize() + packetBufferStructSize
+ return pk.HeaderSize() + pk.data.MemSize() + packetBufferStructSize
+}
+
+// Data returns the handle to data portion of pk.
+func (pk *PacketBuffer) Data() PacketData {
+ return PacketData{pk: pk}
}
// Views returns the underlying storage of the whole packet.
@@ -204,7 +209,7 @@ func (pk *PacketBuffer) Views() []buffer.View {
}
}
- dataViews := pk.Data.Views()
+ dataViews := pk.data.Views()
var vs []buffer.View
if useHeader {
@@ -242,11 +247,11 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
if h.buf != nil {
panic(fmt.Sprintf("consume must not be called twice: type %s", typ))
}
- v, ok := pk.Data.PullUp(size)
+ v, ok := pk.data.PullUp(size)
if !ok {
return
}
- pk.Data.TrimFront(size)
+ pk.data.TrimFront(size)
h.buf = v
return h.buf, true
}
@@ -258,7 +263,7 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
func (pk *PacketBuffer) Clone() *PacketBuffer {
return &PacketBuffer{
PacketBufferEntry: pk.PacketBufferEntry,
- Data: pk.Data.Clone(nil),
+ data: pk.data.Clone(nil),
headers: pk.headers,
header: pk.header,
Hash: pk.Hash,
@@ -339,13 +344,234 @@ func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) {
return h.pk.consume(h.typ, size)
}
+// PacketData represents the data portion of a PacketBuffer.
+type PacketData struct {
+ pk *PacketBuffer
+}
+
+// PullUp returns a contiguous view of size bytes from the beginning of d.
+// Callers should not write to or keep the view for later use.
+func (d PacketData) PullUp(size int) (buffer.View, bool) {
+ return d.pk.data.PullUp(size)
+}
+
+// TrimFront removes count from the beginning of d. It panics if count >
+// d.Size().
+func (d PacketData) TrimFront(count int) {
+ d.pk.data.TrimFront(count)
+}
+
+// CapLength reduces d to at most length bytes.
+func (d PacketData) CapLength(length int) {
+ d.pk.data.CapLength(length)
+}
+
+// Views returns the underlying storage of d in a slice of Views. Caller should
+// not modify the returned slice.
+func (d PacketData) Views() []buffer.View {
+ return d.pk.data.Views()
+}
+
+// AppendView appends v into d, taking the ownership of v.
+func (d PacketData) AppendView(v buffer.View) {
+ d.pk.data.AppendView(v)
+}
+
+// ReadFromData moves at most count bytes from the beginning of srcData to the
+// end of d and returns the number of bytes moved.
+func (d PacketData) ReadFromData(srcData PacketData, count int) int {
+ return srcData.pk.data.ReadToVV(&d.pk.data, count)
+}
+
+// ReadFromVV moves at most count bytes from the beginning of srcVV to the end
+// of d and returns the number of bytes moved.
+func (d PacketData) ReadFromVV(srcVV *buffer.VectorisedView, count int) int {
+ return srcVV.ReadToVV(&d.pk.data, count)
+}
+
+// Size returns the number of bytes in the data payload of the packet.
+func (d PacketData) Size() int {
+ return d.pk.data.Size()
+}
+
+// AsRange returns a Range representing the current data payload of the packet.
+func (d PacketData) AsRange() Range {
+ return Range{
+ pk: d.pk,
+ offset: d.pk.HeaderSize(),
+ length: d.Size(),
+ }
+}
+
+// ExtractVV returns a VectorisedView of d. This method has the semantic to
+// destruct the underlying packet, hence the packet cannot be used again.
+//
+// This method exists for compatibility between PacketBuffer and VectorisedView.
+// It may be removed later and should be used with care.
+func (d PacketData) ExtractVV() buffer.VectorisedView {
+ return d.pk.data
+}
+
+// Replace replaces the data portion of the packet with vv, taking the ownership
+// of vv.
+//
+// This method exists for compatibility between PacketBuffer and VectorisedView.
+// It may be removed later and should be used with care.
+func (d PacketData) Replace(vv buffer.VectorisedView) {
+ d.pk.data = vv
+}
+
+// Range represents a contiguous subportion of a PacketBuffer.
+type Range struct {
+ pk *PacketBuffer
+ offset int
+ length int
+}
+
+// Size returns the number of bytes in r.
+func (r Range) Size() int {
+ return r.length
+}
+
+// SubRange returns a new Range starting at off bytes of r. It returns an empty
+// range if off is out-of-bounds.
+func (r Range) SubRange(off int) Range {
+ if off > r.length {
+ return Range{pk: r.pk}
+ }
+ return Range{
+ pk: r.pk,
+ offset: r.offset + off,
+ length: r.length - off,
+ }
+}
+
+// Capped returns a new Range with the same starting point of r and length
+// capped at max.
+func (r Range) Capped(max int) Range {
+ if r.length <= max {
+ return r
+ }
+ return Range{
+ pk: r.pk,
+ offset: r.offset,
+ length: max,
+ }
+}
+
+// AsView returns the backing storage of r if possible. It will allocate a new
+// View if r spans multiple pieces internally. Caller should not write to the
+// returned View in any way.
+func (r Range) AsView() buffer.View {
+ var allocated bool
+ var v buffer.View
+ r.iterate(func(b []byte) {
+ if v == nil {
+ // v has not been assigned, allowing first view to be returned.
+ v = b
+ } else {
+ // v has been assigned. This range spans more than a view, a new view
+ // needs to be allocated.
+ if !allocated {
+ allocated = true
+ all := make([]byte, 0, r.length)
+ all = append(all, v...)
+ v = all
+ }
+ v = append(v, b...)
+ }
+ })
+ return v
+}
+
+// ToOwnedView returns a owned copy of data in r.
+func (r Range) ToOwnedView() buffer.View {
+ if r.length == 0 {
+ return nil
+ }
+ all := make([]byte, 0, r.length)
+ r.iterate(func(b []byte) {
+ all = append(all, b...)
+ })
+ return all
+}
+
+// Checksum calculates the RFC 1071 checksum for the underlying bytes of r.
+func (r Range) Checksum() uint16 {
+ var c header.Checksumer
+ r.iterate(c.Add)
+ return c.Checksum()
+}
+
+// iterate calls fn for each piece in r. fn is always called with a non-empty
+// slice.
+func (r Range) iterate(fn func([]byte)) {
+ w := window{
+ offset: r.offset,
+ length: r.length,
+ }
+ // Header portion.
+ for i := range r.pk.headers {
+ if b := w.process(r.pk.headers[i].buf); len(b) > 0 {
+ fn(b)
+ }
+ if w.isDone() {
+ break
+ }
+ }
+ // Data portion.
+ if !w.isDone() {
+ for _, v := range r.pk.data.Views() {
+ if b := w.process(v); len(b) > 0 {
+ fn(b)
+ }
+ if w.isDone() {
+ break
+ }
+ }
+ }
+}
+
+// window represents contiguous region of byte stream. User would call process()
+// to input bytes, and obtain a subslice that is inside the window.
+type window struct {
+ offset int
+ length int
+}
+
+// isDone returns true if the window has passed and further process() calls will
+// always return an empty slice. This can be used to end processing early.
+func (w *window) isDone() bool {
+ return w.length == 0
+}
+
+// process feeds b in and returns a subslice that is inside the window. The
+// returned slice will be a subslice of b, and it does not keep b after method
+// returns. This method may return an empty slice if nothing in b is inside the
+// window.
+func (w *window) process(b []byte) (inWindow []byte) {
+ if w.offset >= len(b) {
+ w.offset -= len(b)
+ return nil
+ }
+ if w.offset > 0 {
+ b = b[w.offset:]
+ w.offset = 0
+ }
+ if w.length < len(b) {
+ b = b[:w.length]
+ }
+ w.length -= len(b)
+ return b
+}
+
// PayloadSince returns packet payload starting from and including a particular
// header.
//
// The returned View is owned by the caller - its backing buffer is separate
// from the packet header's underlying packet buffer.
func PayloadSince(h PacketHeader) buffer.View {
- size := h.pk.Data.Size()
+ size := h.pk.data.Size()
for _, hinfo := range h.pk.headers[h.typ:] {
size += len(hinfo.buf)
}
@@ -356,7 +582,7 @@ func PayloadSince(h PacketHeader) buffer.View {
v = append(v, hinfo.buf...)
}
- for _, view := range h.pk.Data.Views() {
+ for _, view := range h.pk.data.Views() {
v = append(v, view...)
}
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index c6fa8da5f..6728370c3 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -15,9 +15,11 @@ package stack
import (
"bytes"
+ "fmt"
"testing"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
func TestPacketHeaderPush(t *testing.T) {
@@ -110,7 +112,7 @@ func TestPacketHeaderPush(t *testing.T) {
if got, want := pk.Size(), allHdrSize+len(test.data); got != want {
t.Errorf("After pk.Size() = %d, want %d", got, want)
}
- checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data)
+ checkData(t, pk, test.data)
checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...),
concatViews(test.link, test.network, test.transport, test.data))
// Check the after values for each header.
@@ -204,7 +206,7 @@ func TestPacketHeaderConsume(t *testing.T) {
transport = test.data[test.link+test.network:][:test.transport]
payload = test.data[allHdrSize:]
)
- checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload)
+ checkData(t, pk, payload)
checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data)
// Check the after values for each header.
checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link)
@@ -340,6 +342,158 @@ func TestPacketHeaderConsumeThenPushPanics(t *testing.T) {
}
}
+func TestPacketBufferData(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ makePkt func(*testing.T) *PacketBuffer
+ data string
+ }{
+ {
+ name: "inbound packet",
+ makePkt: func(*testing.T) *PacketBuffer {
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ Data: vv("aabbbbccccccDATA"),
+ })
+ pkt.LinkHeader().Consume(2)
+ pkt.NetworkHeader().Consume(4)
+ pkt.TransportHeader().Consume(6)
+ return pkt
+ },
+ data: "DATA",
+ },
+ {
+ name: "outbound packet",
+ makePkt: func(*testing.T) *PacketBuffer {
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: 12,
+ Data: vv("DATA"),
+ })
+ copy(pkt.TransportHeader().Push(6), []byte("cccccc"))
+ copy(pkt.NetworkHeader().Push(4), []byte("bbbb"))
+ copy(pkt.LinkHeader().Push(2), []byte("aa"))
+ return pkt
+ },
+ data: "DATA",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ // PullUp
+ for _, n := range []int{1, len(tc.data)} {
+ t.Run(fmt.Sprintf("PullUp%d", n), func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ v, ok := pkt.Data().PullUp(n)
+ wantV := []byte(tc.data)[:n]
+ if !ok || !bytes.Equal(v, wantV) {
+ t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want %q, true", n, v, ok, wantV)
+ }
+ })
+ }
+ t.Run("PullUpOutOfBounds", func(t *testing.T) {
+ n := len(tc.data) + 1
+ pkt := tc.makePkt(t)
+ v, ok := pkt.Data().PullUp(n)
+ if ok || v != nil {
+ t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want nil, false", n, v, ok)
+ }
+ })
+
+ // TrimFront
+ for _, n := range []int{1, len(tc.data)} {
+ t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ pkt.Data().TrimFront(n)
+
+ checkData(t, pkt, []byte(tc.data)[n:])
+ })
+ }
+
+ // CapLength
+ for _, n := range []int{0, 1, len(tc.data)} {
+ t.Run(fmt.Sprintf("CapLength%d", n), func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ pkt.Data().CapLength(n)
+
+ want := []byte(tc.data)
+ if n < len(want) {
+ want = want[:n]
+ }
+ checkData(t, pkt, want)
+ })
+ }
+
+ // Views
+ t.Run("Views", func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ checkData(t, pkt, []byte(tc.data))
+ })
+
+ // AppendView
+ t.Run("AppendView", func(t *testing.T) {
+ s := "APPEND"
+
+ pkt := tc.makePkt(t)
+ pkt.Data().AppendView(buffer.View(s))
+
+ checkData(t, pkt, []byte(tc.data+s))
+ })
+
+ // ReadFromData/VV
+ for _, n := range []int{0, 1, 2, 7, 10, 14, 20} {
+ t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) {
+ s := "TO READ"
+ otherPkt := NewPacketBuffer(PacketBufferOptions{
+ Data: vv(s, s),
+ })
+ s += s
+
+ pkt := tc.makePkt(t)
+ pkt.Data().ReadFromData(otherPkt.Data(), n)
+
+ if n < len(s) {
+ s = s[:n]
+ }
+ checkData(t, pkt, []byte(tc.data+s))
+ })
+ t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) {
+ s := "TO READ"
+ srcVV := vv(s, s)
+ s += s
+
+ pkt := tc.makePkt(t)
+ pkt.Data().ReadFromVV(&srcVV, n)
+
+ if n < len(s) {
+ s = s[:n]
+ }
+ checkData(t, pkt, []byte(tc.data+s))
+ })
+ }
+
+ // ExtractVV
+ t.Run("ExtractVV", func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ extractedVV := pkt.Data().ExtractVV()
+
+ got := extractedVV.ToOwnedView()
+ want := []byte(tc.data)
+ if !bytes.Equal(got, want) {
+ t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want)
+ }
+ })
+
+ // Replace
+ t.Run("Replace", func(t *testing.T) {
+ s := "REPLACED"
+
+ pkt := tc.makePkt(t)
+ pkt.Data().Replace(vv(s))
+
+ checkData(t, pkt, []byte(s))
+ })
+ })
+ }
+}
+
func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) {
t.Helper()
reserved := opts.ReserveHeaderBytes
@@ -356,7 +510,7 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO
if got, want := pk.Size(), len(data); got != want {
t.Errorf("Initial pk.Size() = %d, want %d", got, want)
}
- checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data)
+ checkData(t, pk, data)
checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data)
// Check the initial values for each header.
checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil)
@@ -383,6 +537,70 @@ func checkViewEqual(t *testing.T, what string, got, want buffer.View) {
}
}
+func checkData(t *testing.T, pkt *PacketBuffer, want []byte) {
+ t.Helper()
+ if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) {
+ t.Errorf("pkt.Data().Views() = %x, want %x", got, want)
+ }
+ if got := pkt.Data().Size(); got != len(want) {
+ t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want))
+ }
+
+ t.Run("AsRange", func(t *testing.T) {
+ // Full range
+ checkRange(t, pkt.Data().AsRange(), want)
+
+ // SubRange
+ for _, off := range []int{0, 1, len(want), len(want) + 1} {
+ t.Run(fmt.Sprintf("SubRange%d", off), func(t *testing.T) {
+ // Empty when off is greater than the size of range.
+ var sub []byte
+ if off < len(want) {
+ sub = want[off:]
+ }
+ checkRange(t, pkt.Data().AsRange().SubRange(off), sub)
+ })
+ }
+
+ // Capped
+ for _, n := range []int{0, 1, len(want), len(want) + 1} {
+ t.Run(fmt.Sprintf("Capped%d", n), func(t *testing.T) {
+ sub := want
+ if n < len(sub) {
+ sub = sub[:n]
+ }
+ checkRange(t, pkt.Data().AsRange().Capped(n), sub)
+ })
+ }
+ })
+}
+
+func checkRange(t *testing.T, r Range, data []byte) {
+ if got, want := r.Size(), len(data); got != want {
+ t.Errorf("r.Size() = %d, want %d", got, want)
+ }
+ if got := r.AsView(); !bytes.Equal(got, data) {
+ t.Errorf("r.AsView() = %x, want %x", got, data)
+ }
+ if got := r.ToOwnedView(); !bytes.Equal(got, data) {
+ t.Errorf("r.ToOwnedView() = %x, want %x", got, data)
+ }
+ if got, want := r.Checksum(), header.Checksum(data, 0 /* initial */); got != want {
+ t.Errorf("r.Checksum() = %x, want %x", got, want)
+ }
+}
+
+func vv(pieces ...string) buffer.VectorisedView {
+ var views []buffer.View
+ var size int
+ for _, p := range pieces {
+ v := buffer.View([]byte(p))
+ size += len(v)
+ views = append(views, v)
+ }
+ return buffer.NewVectorisedView(size, views)
+}
+
func makeView(size int) buffer.View {
b := byte(size)
return bytes.Repeat([]byte{b}, size)
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 8e39e828c..f45cf5fdf 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -137,11 +137,11 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ nb, ok := pkt.Data().PullUp(fakeNetHeaderLen)
if !ok {
return
}
- pkt.Data.TrimFront(fakeNetHeaderLen)
+ pkt.Data().TrimFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
@@ -4294,7 +4294,7 @@ func TestWritePacketToRemote(t *testing.T) {
if pkt.Route.RemoteLinkAddress != linkAddr2 {
t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2)
}
- if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" {
+ if diff := cmp.Diff(pkt.Pkt.Data().AsRange().ToOwnedView(), buffer.View(test.payload)); diff != "" {
t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff)
}
})