summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/conntrack.go4
-rw-r--r--pkg/tcpip/stack/iptables_targets.go2
-rw-r--r--pkg/tcpip/stack/ndp_test.go18
-rw-r--r--pkg/tcpip/stack/nic.go6
-rw-r--r--pkg/tcpip/stack/packet_buffer.go248
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go224
-rw-r--r--pkg/tcpip/stack/stack_test.go6
7 files changed, 482 insertions, 26 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index cdb435644..3f083928f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -407,12 +407,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(len(tcpHeader) + pkt.Data.Size())
+ length := uint16(len(tcpHeader) + pkt.Data().Size())
xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
} else if r.RequiresTXTransportChecksum() {
- xsum = header.ChecksumVV(pkt.Data, xsum)
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index d63e9757c..0e8b90c9b 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -153,7 +153,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
if r.RequiresTXTransportChecksum() {
length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- xsum = header.ChecksumVV(pkt.Data, xsum)
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 740bdac28..78a4cb072 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -573,7 +573,11 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(tgt)
snmc := header.SolicitedNodeAddr(tgt)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: header.IPv6Any,
+ Dst: snmc,
+ }))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -619,7 +623,11 @@ func TestDADFail(t *testing.T) {
na.Options().Serialize(header.NDPOptionsSerializer{
header.NDPTargetLinkLayerAddressOption(linkAddr1),
})
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: tgt,
+ Dst: header.IPv6AllNodesMulticastAddress,
+ }))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
@@ -973,7 +981,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
}
opts := ra.Options()
opts.Serialize(optSer)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: ip,
+ Dst: header.IPv6AllNodesMulticastAddress,
+ }))
payloadLength := hdr.UsedLength()
iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
iph.Encode(&header.IPv6Fields{
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index f9323d545..62f7c880e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -725,12 +725,12 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n.mu.RUnlock()
n.stats.DisabledRx.Packets.Increment()
- n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+ n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size()))
return
}
n.stats.Rx.Packets.Increment()
- n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+ n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size()))
networkEndpoint, ok := n.networkEndpoints[protocol]
if !ok {
@@ -881,7 +881,7 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo
// ICMPv4 only guarantees that 8 bytes of the transport protocol will
// be present in the payload. We know that the ports are within the
// first 8 bytes for all known transport protocols.
- transHeader, ok := pkt.Data.PullUp(8)
+ transHeader, ok := pkt.Data().PullUp(8)
if !ok {
return
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 4f013b212..8f288675d 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -59,7 +59,7 @@ type PacketBuffer struct {
// PacketBuffers.
PacketBufferEntry
- // Data holds the payload of the packet.
+ // data holds the payload of the packet.
//
// For inbound packets, Data is initially the whole packet. Then gets moved to
// headers via PacketHeader.Consume, when the packet is being parsed.
@@ -69,7 +69,7 @@ type PacketBuffer struct {
//
// The bytes backing Data are immutable, a.k.a. users shouldn't write to its
// backing storage.
- Data buffer.VectorisedView
+ data buffer.VectorisedView
// headers stores metadata about each header.
headers [numHeaderType]headerInfo
@@ -127,7 +127,7 @@ type PacketBuffer struct {
// NewPacketBuffer creates a new PacketBuffer with opts.
func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer {
pk := &PacketBuffer{
- Data: opts.Data,
+ data: opts.Data,
}
if opts.ReserveHeaderBytes != 0 {
pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes)
@@ -184,13 +184,18 @@ func (pk *PacketBuffer) HeaderSize() int {
// Size returns the size of packet in bytes.
func (pk *PacketBuffer) Size() int {
- return pk.HeaderSize() + pk.Data.Size()
+ return pk.HeaderSize() + pk.data.Size()
}
// MemSize returns the estimation size of the pk in memory, including backing
// buffer data.
func (pk *PacketBuffer) MemSize() int {
- return pk.HeaderSize() + pk.Data.MemSize() + packetBufferStructSize
+ return pk.HeaderSize() + pk.data.MemSize() + packetBufferStructSize
+}
+
+// Data returns the handle to data portion of pk.
+func (pk *PacketBuffer) Data() PacketData {
+ return PacketData{pk: pk}
}
// Views returns the underlying storage of the whole packet.
@@ -204,7 +209,7 @@ func (pk *PacketBuffer) Views() []buffer.View {
}
}
- dataViews := pk.Data.Views()
+ dataViews := pk.data.Views()
var vs []buffer.View
if useHeader {
@@ -242,11 +247,11 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
if h.buf != nil {
panic(fmt.Sprintf("consume must not be called twice: type %s", typ))
}
- v, ok := pk.Data.PullUp(size)
+ v, ok := pk.data.PullUp(size)
if !ok {
return
}
- pk.Data.TrimFront(size)
+ pk.data.TrimFront(size)
h.buf = v
return h.buf, true
}
@@ -258,7 +263,7 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
func (pk *PacketBuffer) Clone() *PacketBuffer {
return &PacketBuffer{
PacketBufferEntry: pk.PacketBufferEntry,
- Data: pk.Data.Clone(nil),
+ data: pk.data.Clone(nil),
headers: pk.headers,
header: pk.header,
Hash: pk.Hash,
@@ -339,13 +344,234 @@ func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) {
return h.pk.consume(h.typ, size)
}
+// PacketData represents the data portion of a PacketBuffer.
+type PacketData struct {
+ pk *PacketBuffer
+}
+
+// PullUp returns a contiguous view of size bytes from the beginning of d.
+// Callers should not write to or keep the view for later use.
+func (d PacketData) PullUp(size int) (buffer.View, bool) {
+ return d.pk.data.PullUp(size)
+}
+
+// TrimFront removes count from the beginning of d. It panics if count >
+// d.Size().
+func (d PacketData) TrimFront(count int) {
+ d.pk.data.TrimFront(count)
+}
+
+// CapLength reduces d to at most length bytes.
+func (d PacketData) CapLength(length int) {
+ d.pk.data.CapLength(length)
+}
+
+// Views returns the underlying storage of d in a slice of Views. Caller should
+// not modify the returned slice.
+func (d PacketData) Views() []buffer.View {
+ return d.pk.data.Views()
+}
+
+// AppendView appends v into d, taking the ownership of v.
+func (d PacketData) AppendView(v buffer.View) {
+ d.pk.data.AppendView(v)
+}
+
+// ReadFromData moves at most count bytes from the beginning of srcData to the
+// end of d and returns the number of bytes moved.
+func (d PacketData) ReadFromData(srcData PacketData, count int) int {
+ return srcData.pk.data.ReadToVV(&d.pk.data, count)
+}
+
+// ReadFromVV moves at most count bytes from the beginning of srcVV to the end
+// of d and returns the number of bytes moved.
+func (d PacketData) ReadFromVV(srcVV *buffer.VectorisedView, count int) int {
+ return srcVV.ReadToVV(&d.pk.data, count)
+}
+
+// Size returns the number of bytes in the data payload of the packet.
+func (d PacketData) Size() int {
+ return d.pk.data.Size()
+}
+
+// AsRange returns a Range representing the current data payload of the packet.
+func (d PacketData) AsRange() Range {
+ return Range{
+ pk: d.pk,
+ offset: d.pk.HeaderSize(),
+ length: d.Size(),
+ }
+}
+
+// ExtractVV returns a VectorisedView of d. This method has the semantic to
+// destruct the underlying packet, hence the packet cannot be used again.
+//
+// This method exists for compatibility between PacketBuffer and VectorisedView.
+// It may be removed later and should be used with care.
+func (d PacketData) ExtractVV() buffer.VectorisedView {
+ return d.pk.data
+}
+
+// Replace replaces the data portion of the packet with vv, taking the ownership
+// of vv.
+//
+// This method exists for compatibility between PacketBuffer and VectorisedView.
+// It may be removed later and should be used with care.
+func (d PacketData) Replace(vv buffer.VectorisedView) {
+ d.pk.data = vv
+}
+
+// Range represents a contiguous subportion of a PacketBuffer.
+type Range struct {
+ pk *PacketBuffer
+ offset int
+ length int
+}
+
+// Size returns the number of bytes in r.
+func (r Range) Size() int {
+ return r.length
+}
+
+// SubRange returns a new Range starting at off bytes of r. It returns an empty
+// range if off is out-of-bounds.
+func (r Range) SubRange(off int) Range {
+ if off > r.length {
+ return Range{pk: r.pk}
+ }
+ return Range{
+ pk: r.pk,
+ offset: r.offset + off,
+ length: r.length - off,
+ }
+}
+
+// Capped returns a new Range with the same starting point of r and length
+// capped at max.
+func (r Range) Capped(max int) Range {
+ if r.length <= max {
+ return r
+ }
+ return Range{
+ pk: r.pk,
+ offset: r.offset,
+ length: max,
+ }
+}
+
+// AsView returns the backing storage of r if possible. It will allocate a new
+// View if r spans multiple pieces internally. Caller should not write to the
+// returned View in any way.
+func (r Range) AsView() buffer.View {
+ var allocated bool
+ var v buffer.View
+ r.iterate(func(b []byte) {
+ if v == nil {
+ // v has not been assigned, allowing first view to be returned.
+ v = b
+ } else {
+ // v has been assigned. This range spans more than a view, a new view
+ // needs to be allocated.
+ if !allocated {
+ allocated = true
+ all := make([]byte, 0, r.length)
+ all = append(all, v...)
+ v = all
+ }
+ v = append(v, b...)
+ }
+ })
+ return v
+}
+
+// ToOwnedView returns a owned copy of data in r.
+func (r Range) ToOwnedView() buffer.View {
+ if r.length == 0 {
+ return nil
+ }
+ all := make([]byte, 0, r.length)
+ r.iterate(func(b []byte) {
+ all = append(all, b...)
+ })
+ return all
+}
+
+// Checksum calculates the RFC 1071 checksum for the underlying bytes of r.
+func (r Range) Checksum() uint16 {
+ var c header.Checksumer
+ r.iterate(c.Add)
+ return c.Checksum()
+}
+
+// iterate calls fn for each piece in r. fn is always called with a non-empty
+// slice.
+func (r Range) iterate(fn func([]byte)) {
+ w := window{
+ offset: r.offset,
+ length: r.length,
+ }
+ // Header portion.
+ for i := range r.pk.headers {
+ if b := w.process(r.pk.headers[i].buf); len(b) > 0 {
+ fn(b)
+ }
+ if w.isDone() {
+ break
+ }
+ }
+ // Data portion.
+ if !w.isDone() {
+ for _, v := range r.pk.data.Views() {
+ if b := w.process(v); len(b) > 0 {
+ fn(b)
+ }
+ if w.isDone() {
+ break
+ }
+ }
+ }
+}
+
+// window represents contiguous region of byte stream. User would call process()
+// to input bytes, and obtain a subslice that is inside the window.
+type window struct {
+ offset int
+ length int
+}
+
+// isDone returns true if the window has passed and further process() calls will
+// always return an empty slice. This can be used to end processing early.
+func (w *window) isDone() bool {
+ return w.length == 0
+}
+
+// process feeds b in and returns a subslice that is inside the window. The
+// returned slice will be a subslice of b, and it does not keep b after method
+// returns. This method may return an empty slice if nothing in b is inside the
+// window.
+func (w *window) process(b []byte) (inWindow []byte) {
+ if w.offset >= len(b) {
+ w.offset -= len(b)
+ return nil
+ }
+ if w.offset > 0 {
+ b = b[w.offset:]
+ w.offset = 0
+ }
+ if w.length < len(b) {
+ b = b[:w.length]
+ }
+ w.length -= len(b)
+ return b
+}
+
// PayloadSince returns packet payload starting from and including a particular
// header.
//
// The returned View is owned by the caller - its backing buffer is separate
// from the packet header's underlying packet buffer.
func PayloadSince(h PacketHeader) buffer.View {
- size := h.pk.Data.Size()
+ size := h.pk.data.Size()
for _, hinfo := range h.pk.headers[h.typ:] {
size += len(hinfo.buf)
}
@@ -356,7 +582,7 @@ func PayloadSince(h PacketHeader) buffer.View {
v = append(v, hinfo.buf...)
}
- for _, view := range h.pk.Data.Views() {
+ for _, view := range h.pk.data.Views() {
v = append(v, view...)
}
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index c6fa8da5f..6728370c3 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -15,9 +15,11 @@ package stack
import (
"bytes"
+ "fmt"
"testing"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
func TestPacketHeaderPush(t *testing.T) {
@@ -110,7 +112,7 @@ func TestPacketHeaderPush(t *testing.T) {
if got, want := pk.Size(), allHdrSize+len(test.data); got != want {
t.Errorf("After pk.Size() = %d, want %d", got, want)
}
- checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data)
+ checkData(t, pk, test.data)
checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...),
concatViews(test.link, test.network, test.transport, test.data))
// Check the after values for each header.
@@ -204,7 +206,7 @@ func TestPacketHeaderConsume(t *testing.T) {
transport = test.data[test.link+test.network:][:test.transport]
payload = test.data[allHdrSize:]
)
- checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload)
+ checkData(t, pk, payload)
checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data)
// Check the after values for each header.
checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link)
@@ -340,6 +342,158 @@ func TestPacketHeaderConsumeThenPushPanics(t *testing.T) {
}
}
+func TestPacketBufferData(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ makePkt func(*testing.T) *PacketBuffer
+ data string
+ }{
+ {
+ name: "inbound packet",
+ makePkt: func(*testing.T) *PacketBuffer {
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ Data: vv("aabbbbccccccDATA"),
+ })
+ pkt.LinkHeader().Consume(2)
+ pkt.NetworkHeader().Consume(4)
+ pkt.TransportHeader().Consume(6)
+ return pkt
+ },
+ data: "DATA",
+ },
+ {
+ name: "outbound packet",
+ makePkt: func(*testing.T) *PacketBuffer {
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: 12,
+ Data: vv("DATA"),
+ })
+ copy(pkt.TransportHeader().Push(6), []byte("cccccc"))
+ copy(pkt.NetworkHeader().Push(4), []byte("bbbb"))
+ copy(pkt.LinkHeader().Push(2), []byte("aa"))
+ return pkt
+ },
+ data: "DATA",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ // PullUp
+ for _, n := range []int{1, len(tc.data)} {
+ t.Run(fmt.Sprintf("PullUp%d", n), func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ v, ok := pkt.Data().PullUp(n)
+ wantV := []byte(tc.data)[:n]
+ if !ok || !bytes.Equal(v, wantV) {
+ t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want %q, true", n, v, ok, wantV)
+ }
+ })
+ }
+ t.Run("PullUpOutOfBounds", func(t *testing.T) {
+ n := len(tc.data) + 1
+ pkt := tc.makePkt(t)
+ v, ok := pkt.Data().PullUp(n)
+ if ok || v != nil {
+ t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want nil, false", n, v, ok)
+ }
+ })
+
+ // TrimFront
+ for _, n := range []int{1, len(tc.data)} {
+ t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ pkt.Data().TrimFront(n)
+
+ checkData(t, pkt, []byte(tc.data)[n:])
+ })
+ }
+
+ // CapLength
+ for _, n := range []int{0, 1, len(tc.data)} {
+ t.Run(fmt.Sprintf("CapLength%d", n), func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ pkt.Data().CapLength(n)
+
+ want := []byte(tc.data)
+ if n < len(want) {
+ want = want[:n]
+ }
+ checkData(t, pkt, want)
+ })
+ }
+
+ // Views
+ t.Run("Views", func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ checkData(t, pkt, []byte(tc.data))
+ })
+
+ // AppendView
+ t.Run("AppendView", func(t *testing.T) {
+ s := "APPEND"
+
+ pkt := tc.makePkt(t)
+ pkt.Data().AppendView(buffer.View(s))
+
+ checkData(t, pkt, []byte(tc.data+s))
+ })
+
+ // ReadFromData/VV
+ for _, n := range []int{0, 1, 2, 7, 10, 14, 20} {
+ t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) {
+ s := "TO READ"
+ otherPkt := NewPacketBuffer(PacketBufferOptions{
+ Data: vv(s, s),
+ })
+ s += s
+
+ pkt := tc.makePkt(t)
+ pkt.Data().ReadFromData(otherPkt.Data(), n)
+
+ if n < len(s) {
+ s = s[:n]
+ }
+ checkData(t, pkt, []byte(tc.data+s))
+ })
+ t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) {
+ s := "TO READ"
+ srcVV := vv(s, s)
+ s += s
+
+ pkt := tc.makePkt(t)
+ pkt.Data().ReadFromVV(&srcVV, n)
+
+ if n < len(s) {
+ s = s[:n]
+ }
+ checkData(t, pkt, []byte(tc.data+s))
+ })
+ }
+
+ // ExtractVV
+ t.Run("ExtractVV", func(t *testing.T) {
+ pkt := tc.makePkt(t)
+ extractedVV := pkt.Data().ExtractVV()
+
+ got := extractedVV.ToOwnedView()
+ want := []byte(tc.data)
+ if !bytes.Equal(got, want) {
+ t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want)
+ }
+ })
+
+ // Replace
+ t.Run("Replace", func(t *testing.T) {
+ s := "REPLACED"
+
+ pkt := tc.makePkt(t)
+ pkt.Data().Replace(vv(s))
+
+ checkData(t, pkt, []byte(s))
+ })
+ })
+ }
+}
+
func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) {
t.Helper()
reserved := opts.ReserveHeaderBytes
@@ -356,7 +510,7 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO
if got, want := pk.Size(), len(data); got != want {
t.Errorf("Initial pk.Size() = %d, want %d", got, want)
}
- checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data)
+ checkData(t, pk, data)
checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data)
// Check the initial values for each header.
checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil)
@@ -383,6 +537,70 @@ func checkViewEqual(t *testing.T, what string, got, want buffer.View) {
}
}
+func checkData(t *testing.T, pkt *PacketBuffer, want []byte) {
+ t.Helper()
+ if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) {
+ t.Errorf("pkt.Data().Views() = %x, want %x", got, want)
+ }
+ if got := pkt.Data().Size(); got != len(want) {
+ t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want))
+ }
+
+ t.Run("AsRange", func(t *testing.T) {
+ // Full range
+ checkRange(t, pkt.Data().AsRange(), want)
+
+ // SubRange
+ for _, off := range []int{0, 1, len(want), len(want) + 1} {
+ t.Run(fmt.Sprintf("SubRange%d", off), func(t *testing.T) {
+ // Empty when off is greater than the size of range.
+ var sub []byte
+ if off < len(want) {
+ sub = want[off:]
+ }
+ checkRange(t, pkt.Data().AsRange().SubRange(off), sub)
+ })
+ }
+
+ // Capped
+ for _, n := range []int{0, 1, len(want), len(want) + 1} {
+ t.Run(fmt.Sprintf("Capped%d", n), func(t *testing.T) {
+ sub := want
+ if n < len(sub) {
+ sub = sub[:n]
+ }
+ checkRange(t, pkt.Data().AsRange().Capped(n), sub)
+ })
+ }
+ })
+}
+
+func checkRange(t *testing.T, r Range, data []byte) {
+ if got, want := r.Size(), len(data); got != want {
+ t.Errorf("r.Size() = %d, want %d", got, want)
+ }
+ if got := r.AsView(); !bytes.Equal(got, data) {
+ t.Errorf("r.AsView() = %x, want %x", got, data)
+ }
+ if got := r.ToOwnedView(); !bytes.Equal(got, data) {
+ t.Errorf("r.ToOwnedView() = %x, want %x", got, data)
+ }
+ if got, want := r.Checksum(), header.Checksum(data, 0 /* initial */); got != want {
+ t.Errorf("r.Checksum() = %x, want %x", got, want)
+ }
+}
+
+func vv(pieces ...string) buffer.VectorisedView {
+ var views []buffer.View
+ var size int
+ for _, p := range pieces {
+ v := buffer.View([]byte(p))
+ size += len(v)
+ views = append(views, v)
+ }
+ return buffer.NewVectorisedView(size, views)
+}
+
func makeView(size int) buffer.View {
b := byte(size)
return bytes.Repeat([]byte{b}, size)
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 8e39e828c..f45cf5fdf 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -137,11 +137,11 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ nb, ok := pkt.Data().PullUp(fakeNetHeaderLen)
if !ok {
return
}
- pkt.Data.TrimFront(fakeNetHeaderLen)
+ pkt.Data().TrimFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
@@ -4294,7 +4294,7 @@ func TestWritePacketToRemote(t *testing.T) {
if pkt.Route.RemoteLinkAddress != linkAddr2 {
t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2)
}
- if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" {
+ if diff := cmp.Diff(pkt.Pkt.Data().AsRange().ToOwnedView(), buffer.View(test.payload)); diff != "" {
t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff)
}
})