diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_targets.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 248 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer_test.go | 224 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 6 |
7 files changed, 482 insertions, 26 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index cdb435644..3f083928f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -407,12 +407,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data.Size()) + length := uint16(len(tcpHeader) + pkt.Data().Size()) xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumVV(pkt.Data, xsum) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index d63e9757c..0e8b90c9b 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -153,7 +153,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs if r.RequiresTXTransportChecksum() { length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumVV(pkt.Data, xsum) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 740bdac28..78a4cb072 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -573,7 +573,11 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(tgt) snmc := header.SolicitedNodeAddr(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: header.IPv6Any, + Dst: snmc, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -619,7 +623,11 @@ func TestDADFail(t *testing.T) { na.Options().Serialize(header.NDPOptionsSerializer{ header.NDPTargetLinkLayerAddressOption(linkAddr1), }) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: tgt, + Dst: header.IPv6AllNodesMulticastAddress, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -973,7 +981,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo } opts := ra.Options() opts.Serialize(optSer) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: ip, + Dst: header.IPv6AllNodesMulticastAddress, + })) payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) iph.Encode(&header.IPv6Fields{ diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f9323d545..62f7c880e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -725,12 +725,12 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.mu.RUnlock() n.stats.DisabledRx.Packets.Increment() - n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size())) return } n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size())) networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { @@ -881,7 +881,7 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - transHeader, ok := pkt.Data.PullUp(8) + transHeader, ok := pkt.Data().PullUp(8) if !ok { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 4f013b212..8f288675d 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -59,7 +59,7 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // Data holds the payload of the packet. + // data holds the payload of the packet. // // For inbound packets, Data is initially the whole packet. Then gets moved to // headers via PacketHeader.Consume, when the packet is being parsed. @@ -69,7 +69,7 @@ type PacketBuffer struct { // // The bytes backing Data are immutable, a.k.a. users shouldn't write to its // backing storage. - Data buffer.VectorisedView + data buffer.VectorisedView // headers stores metadata about each header. headers [numHeaderType]headerInfo @@ -127,7 +127,7 @@ type PacketBuffer struct { // NewPacketBuffer creates a new PacketBuffer with opts. func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { pk := &PacketBuffer{ - Data: opts.Data, + data: opts.Data, } if opts.ReserveHeaderBytes != 0 { pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) @@ -184,13 +184,18 @@ func (pk *PacketBuffer) HeaderSize() int { // Size returns the size of packet in bytes. func (pk *PacketBuffer) Size() int { - return pk.HeaderSize() + pk.Data.Size() + return pk.HeaderSize() + pk.data.Size() } // MemSize returns the estimation size of the pk in memory, including backing // buffer data. func (pk *PacketBuffer) MemSize() int { - return pk.HeaderSize() + pk.Data.MemSize() + packetBufferStructSize + return pk.HeaderSize() + pk.data.MemSize() + packetBufferStructSize +} + +// Data returns the handle to data portion of pk. +func (pk *PacketBuffer) Data() PacketData { + return PacketData{pk: pk} } // Views returns the underlying storage of the whole packet. @@ -204,7 +209,7 @@ func (pk *PacketBuffer) Views() []buffer.View { } } - dataViews := pk.Data.Views() + dataViews := pk.data.Views() var vs []buffer.View if useHeader { @@ -242,11 +247,11 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum if h.buf != nil { panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) } - v, ok := pk.Data.PullUp(size) + v, ok := pk.data.PullUp(size) if !ok { return } - pk.Data.TrimFront(size) + pk.data.TrimFront(size) h.buf = v return h.buf, true } @@ -258,7 +263,7 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum func (pk *PacketBuffer) Clone() *PacketBuffer { return &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), + data: pk.data.Clone(nil), headers: pk.headers, header: pk.header, Hash: pk.Hash, @@ -339,13 +344,234 @@ func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { return h.pk.consume(h.typ, size) } +// PacketData represents the data portion of a PacketBuffer. +type PacketData struct { + pk *PacketBuffer +} + +// PullUp returns a contiguous view of size bytes from the beginning of d. +// Callers should not write to or keep the view for later use. +func (d PacketData) PullUp(size int) (buffer.View, bool) { + return d.pk.data.PullUp(size) +} + +// TrimFront removes count from the beginning of d. It panics if count > +// d.Size(). +func (d PacketData) TrimFront(count int) { + d.pk.data.TrimFront(count) +} + +// CapLength reduces d to at most length bytes. +func (d PacketData) CapLength(length int) { + d.pk.data.CapLength(length) +} + +// Views returns the underlying storage of d in a slice of Views. Caller should +// not modify the returned slice. +func (d PacketData) Views() []buffer.View { + return d.pk.data.Views() +} + +// AppendView appends v into d, taking the ownership of v. +func (d PacketData) AppendView(v buffer.View) { + d.pk.data.AppendView(v) +} + +// ReadFromData moves at most count bytes from the beginning of srcData to the +// end of d and returns the number of bytes moved. +func (d PacketData) ReadFromData(srcData PacketData, count int) int { + return srcData.pk.data.ReadToVV(&d.pk.data, count) +} + +// ReadFromVV moves at most count bytes from the beginning of srcVV to the end +// of d and returns the number of bytes moved. +func (d PacketData) ReadFromVV(srcVV *buffer.VectorisedView, count int) int { + return srcVV.ReadToVV(&d.pk.data, count) +} + +// Size returns the number of bytes in the data payload of the packet. +func (d PacketData) Size() int { + return d.pk.data.Size() +} + +// AsRange returns a Range representing the current data payload of the packet. +func (d PacketData) AsRange() Range { + return Range{ + pk: d.pk, + offset: d.pk.HeaderSize(), + length: d.Size(), + } +} + +// ExtractVV returns a VectorisedView of d. This method has the semantic to +// destruct the underlying packet, hence the packet cannot be used again. +// +// This method exists for compatibility between PacketBuffer and VectorisedView. +// It may be removed later and should be used with care. +func (d PacketData) ExtractVV() buffer.VectorisedView { + return d.pk.data +} + +// Replace replaces the data portion of the packet with vv, taking the ownership +// of vv. +// +// This method exists for compatibility between PacketBuffer and VectorisedView. +// It may be removed later and should be used with care. +func (d PacketData) Replace(vv buffer.VectorisedView) { + d.pk.data = vv +} + +// Range represents a contiguous subportion of a PacketBuffer. +type Range struct { + pk *PacketBuffer + offset int + length int +} + +// Size returns the number of bytes in r. +func (r Range) Size() int { + return r.length +} + +// SubRange returns a new Range starting at off bytes of r. It returns an empty +// range if off is out-of-bounds. +func (r Range) SubRange(off int) Range { + if off > r.length { + return Range{pk: r.pk} + } + return Range{ + pk: r.pk, + offset: r.offset + off, + length: r.length - off, + } +} + +// Capped returns a new Range with the same starting point of r and length +// capped at max. +func (r Range) Capped(max int) Range { + if r.length <= max { + return r + } + return Range{ + pk: r.pk, + offset: r.offset, + length: max, + } +} + +// AsView returns the backing storage of r if possible. It will allocate a new +// View if r spans multiple pieces internally. Caller should not write to the +// returned View in any way. +func (r Range) AsView() buffer.View { + var allocated bool + var v buffer.View + r.iterate(func(b []byte) { + if v == nil { + // v has not been assigned, allowing first view to be returned. + v = b + } else { + // v has been assigned. This range spans more than a view, a new view + // needs to be allocated. + if !allocated { + allocated = true + all := make([]byte, 0, r.length) + all = append(all, v...) + v = all + } + v = append(v, b...) + } + }) + return v +} + +// ToOwnedView returns a owned copy of data in r. +func (r Range) ToOwnedView() buffer.View { + if r.length == 0 { + return nil + } + all := make([]byte, 0, r.length) + r.iterate(func(b []byte) { + all = append(all, b...) + }) + return all +} + +// Checksum calculates the RFC 1071 checksum for the underlying bytes of r. +func (r Range) Checksum() uint16 { + var c header.Checksumer + r.iterate(c.Add) + return c.Checksum() +} + +// iterate calls fn for each piece in r. fn is always called with a non-empty +// slice. +func (r Range) iterate(fn func([]byte)) { + w := window{ + offset: r.offset, + length: r.length, + } + // Header portion. + for i := range r.pk.headers { + if b := w.process(r.pk.headers[i].buf); len(b) > 0 { + fn(b) + } + if w.isDone() { + break + } + } + // Data portion. + if !w.isDone() { + for _, v := range r.pk.data.Views() { + if b := w.process(v); len(b) > 0 { + fn(b) + } + if w.isDone() { + break + } + } + } +} + +// window represents contiguous region of byte stream. User would call process() +// to input bytes, and obtain a subslice that is inside the window. +type window struct { + offset int + length int +} + +// isDone returns true if the window has passed and further process() calls will +// always return an empty slice. This can be used to end processing early. +func (w *window) isDone() bool { + return w.length == 0 +} + +// process feeds b in and returns a subslice that is inside the window. The +// returned slice will be a subslice of b, and it does not keep b after method +// returns. This method may return an empty slice if nothing in b is inside the +// window. +func (w *window) process(b []byte) (inWindow []byte) { + if w.offset >= len(b) { + w.offset -= len(b) + return nil + } + if w.offset > 0 { + b = b[w.offset:] + w.offset = 0 + } + if w.length < len(b) { + b = b[:w.length] + } + w.length -= len(b) + return b +} + // PayloadSince returns packet payload starting from and including a particular // header. // // The returned View is owned by the caller - its backing buffer is separate // from the packet header's underlying packet buffer. func PayloadSince(h PacketHeader) buffer.View { - size := h.pk.Data.Size() + size := h.pk.data.Size() for _, hinfo := range h.pk.headers[h.typ:] { size += len(hinfo.buf) } @@ -356,7 +582,7 @@ func PayloadSince(h PacketHeader) buffer.View { v = append(v, hinfo.buf...) } - for _, view := range h.pk.Data.Views() { + for _, view := range h.pk.data.Views() { v = append(v, view...) } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index c6fa8da5f..6728370c3 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -15,9 +15,11 @@ package stack import ( "bytes" + "fmt" "testing" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) func TestPacketHeaderPush(t *testing.T) { @@ -110,7 +112,7 @@ func TestPacketHeaderPush(t *testing.T) { if got, want := pk.Size(), allHdrSize+len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data) + checkData(t, pk, test.data) checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), concatViews(test.link, test.network, test.transport, test.data)) // Check the after values for each header. @@ -204,7 +206,7 @@ func TestPacketHeaderConsume(t *testing.T) { transport = test.data[test.link+test.network:][:test.transport] payload = test.data[allHdrSize:] ) - checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload) + checkData(t, pk, payload) checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) // Check the after values for each header. checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) @@ -340,6 +342,158 @@ func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { } } +func TestPacketBufferData(t *testing.T) { + for _, tc := range []struct { + name string + makePkt func(*testing.T) *PacketBuffer + data string + }{ + { + name: "inbound packet", + makePkt: func(*testing.T) *PacketBuffer { + pkt := NewPacketBuffer(PacketBufferOptions{ + Data: vv("aabbbbccccccDATA"), + }) + pkt.LinkHeader().Consume(2) + pkt.NetworkHeader().Consume(4) + pkt.TransportHeader().Consume(6) + return pkt + }, + data: "DATA", + }, + { + name: "outbound packet", + makePkt: func(*testing.T) *PacketBuffer { + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: 12, + Data: vv("DATA"), + }) + copy(pkt.TransportHeader().Push(6), []byte("cccccc")) + copy(pkt.NetworkHeader().Push(4), []byte("bbbb")) + copy(pkt.LinkHeader().Push(2), []byte("aa")) + return pkt + }, + data: "DATA", + }, + } { + t.Run(tc.name, func(t *testing.T) { + // PullUp + for _, n := range []int{1, len(tc.data)} { + t.Run(fmt.Sprintf("PullUp%d", n), func(t *testing.T) { + pkt := tc.makePkt(t) + v, ok := pkt.Data().PullUp(n) + wantV := []byte(tc.data)[:n] + if !ok || !bytes.Equal(v, wantV) { + t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want %q, true", n, v, ok, wantV) + } + }) + } + t.Run("PullUpOutOfBounds", func(t *testing.T) { + n := len(tc.data) + 1 + pkt := tc.makePkt(t) + v, ok := pkt.Data().PullUp(n) + if ok || v != nil { + t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want nil, false", n, v, ok) + } + }) + + // TrimFront + for _, n := range []int{1, len(tc.data)} { + t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) { + pkt := tc.makePkt(t) + pkt.Data().TrimFront(n) + + checkData(t, pkt, []byte(tc.data)[n:]) + }) + } + + // CapLength + for _, n := range []int{0, 1, len(tc.data)} { + t.Run(fmt.Sprintf("CapLength%d", n), func(t *testing.T) { + pkt := tc.makePkt(t) + pkt.Data().CapLength(n) + + want := []byte(tc.data) + if n < len(want) { + want = want[:n] + } + checkData(t, pkt, want) + }) + } + + // Views + t.Run("Views", func(t *testing.T) { + pkt := tc.makePkt(t) + checkData(t, pkt, []byte(tc.data)) + }) + + // AppendView + t.Run("AppendView", func(t *testing.T) { + s := "APPEND" + + pkt := tc.makePkt(t) + pkt.Data().AppendView(buffer.View(s)) + + checkData(t, pkt, []byte(tc.data+s)) + }) + + // ReadFromData/VV + for _, n := range []int{0, 1, 2, 7, 10, 14, 20} { + t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) { + s := "TO READ" + otherPkt := NewPacketBuffer(PacketBufferOptions{ + Data: vv(s, s), + }) + s += s + + pkt := tc.makePkt(t) + pkt.Data().ReadFromData(otherPkt.Data(), n) + + if n < len(s) { + s = s[:n] + } + checkData(t, pkt, []byte(tc.data+s)) + }) + t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) { + s := "TO READ" + srcVV := vv(s, s) + s += s + + pkt := tc.makePkt(t) + pkt.Data().ReadFromVV(&srcVV, n) + + if n < len(s) { + s = s[:n] + } + checkData(t, pkt, []byte(tc.data+s)) + }) + } + + // ExtractVV + t.Run("ExtractVV", func(t *testing.T) { + pkt := tc.makePkt(t) + extractedVV := pkt.Data().ExtractVV() + + got := extractedVV.ToOwnedView() + want := []byte(tc.data) + if !bytes.Equal(got, want) { + t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want) + } + }) + + // Replace + t.Run("Replace", func(t *testing.T) { + s := "REPLACED" + + pkt := tc.makePkt(t) + pkt.Data().Replace(vv(s)) + + checkData(t, pkt, []byte(s)) + }) + }) + } +} + func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { t.Helper() reserved := opts.ReserveHeaderBytes @@ -356,7 +510,7 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO if got, want := pk.Size(), len(data); got != want { t.Errorf("Initial pk.Size() = %d, want %d", got, want) } - checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data) + checkData(t, pk, data) checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) // Check the initial values for each header. checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) @@ -383,6 +537,70 @@ func checkViewEqual(t *testing.T, what string, got, want buffer.View) { } } +func checkData(t *testing.T, pkt *PacketBuffer, want []byte) { + t.Helper() + if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) { + t.Errorf("pkt.Data().Views() = %x, want %x", got, want) + } + if got := pkt.Data().Size(); got != len(want) { + t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want)) + } + + t.Run("AsRange", func(t *testing.T) { + // Full range + checkRange(t, pkt.Data().AsRange(), want) + + // SubRange + for _, off := range []int{0, 1, len(want), len(want) + 1} { + t.Run(fmt.Sprintf("SubRange%d", off), func(t *testing.T) { + // Empty when off is greater than the size of range. + var sub []byte + if off < len(want) { + sub = want[off:] + } + checkRange(t, pkt.Data().AsRange().SubRange(off), sub) + }) + } + + // Capped + for _, n := range []int{0, 1, len(want), len(want) + 1} { + t.Run(fmt.Sprintf("Capped%d", n), func(t *testing.T) { + sub := want + if n < len(sub) { + sub = sub[:n] + } + checkRange(t, pkt.Data().AsRange().Capped(n), sub) + }) + } + }) +} + +func checkRange(t *testing.T, r Range, data []byte) { + if got, want := r.Size(), len(data); got != want { + t.Errorf("r.Size() = %d, want %d", got, want) + } + if got := r.AsView(); !bytes.Equal(got, data) { + t.Errorf("r.AsView() = %x, want %x", got, data) + } + if got := r.ToOwnedView(); !bytes.Equal(got, data) { + t.Errorf("r.ToOwnedView() = %x, want %x", got, data) + } + if got, want := r.Checksum(), header.Checksum(data, 0 /* initial */); got != want { + t.Errorf("r.Checksum() = %x, want %x", got, want) + } +} + +func vv(pieces ...string) buffer.VectorisedView { + var views []buffer.View + var size int + for _, p := range pieces { + v := buffer.View([]byte(p)) + size += len(v) + views = append(views, v) + } + return buffer.NewVectorisedView(size, views) +} + func makeView(size int) buffer.View { b := byte(size) return bytes.Repeat([]byte{b}, size) diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 8e39e828c..f45cf5fdf 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -137,11 +137,11 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + nb, ok := pkt.Data().PullUp(fakeNetHeaderLen) if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) + pkt.Data().TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), @@ -4294,7 +4294,7 @@ func TestWritePacketToRemote(t *testing.T) { if pkt.Route.RemoteLinkAddress != linkAddr2 { t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2) } - if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { + if diff := cmp.Diff(pkt.Pkt.Data().AsRange().ToOwnedView(), buffer.View(test.payload)); diff != "" { t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) } }) |