diff options
54 files changed, 960 insertions, 331 deletions
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index b05e81526..f4a30effd 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -196,7 +196,7 @@ func (vv *VectorisedView) CapLength(length int) { // If the buffer argument is large enough to contain all the Views of this // VectorisedView, the method will avoid allocations and use the buffer to // store the Views of the clone. -func (vv *VectorisedView) Clone(buffer []View) VectorisedView { +func (vv VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } @@ -290,6 +290,14 @@ func (vv *VectorisedView) AppendView(v View) { vv.size += len(v) } +// AppendViews appends views to vv. +func (vv *VectorisedView) AppendViews(views []View) { + vv.views = append(vv.views, views...) + for _, v := range views { + vv.size += len(v) + } +} + // Readers returns a bytes.Reader for each of vv's views. func (vv *VectorisedView) Readers() []bytes.Reader { readers := make([]bytes.Reader, 0, len(vv.views)) diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 78b2faa26..d296d9c2b 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -45,6 +45,11 @@ func vv(size int, pieces ...string) buffer.VectorisedView { return buffer.NewVectorisedView(size, views) } +// v returns a buffer.View containing piece. +func v(piece string) buffer.View { + return buffer.View(piece) +} + var capLengthTestCases = []struct { comment string in buffer.VectorisedView @@ -125,6 +130,12 @@ var trimFrontTestCases = []struct { want: vv(1, "3"), }, { + comment: "Case with one empty Views", + in: vv(3, "1", "", "23"), + count: 2, + want: vv(1, "3"), + }, + { comment: "Corner case with negative count", in: vv(1, "1"), count: -1, @@ -566,11 +577,11 @@ func TestAppendView(t *testing.T) { in buffer.View want buffer.VectorisedView }{ - {buffer.VectorisedView{}, nil, buffer.VectorisedView{}}, - {buffer.VectorisedView{}, buffer.View{}, buffer.VectorisedView{}}, - {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), nil, buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}})}, - {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), buffer.View{}, buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}})}, - {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), buffer.View{'e'}, buffer.NewVectorisedView(5, []buffer.View{{'a', 'b', 'c', 'd'}, {'e'}})}, + {vv(0), nil, vv(0)}, + {vv(0), v(""), vv(0)}, + {vv(4, "abcd"), nil, vv(4, "abcd")}, + {vv(4, "abcd"), v(""), vv(4, "abcd")}, + {vv(4, "abcd"), v("e"), vv(5, "abcd", "e")}, } for _, tc := range testCases { tc.vv.AppendView(tc.in) @@ -580,6 +591,31 @@ func TestAppendView(t *testing.T) { } } +func TestAppendViews(t *testing.T) { + testCases := []struct { + vv buffer.VectorisedView + in []buffer.View + want buffer.VectorisedView + }{ + {vv(0), nil, vv(0)}, + {vv(0), []buffer.View{}, vv(0)}, + {vv(0), []buffer.View{v("")}, vv(0, "")}, + {vv(4, "abcd"), nil, vv(4, "abcd")}, + {vv(4, "abcd"), []buffer.View{}, vv(4, "abcd")}, + {vv(4, "abcd"), []buffer.View{v("")}, vv(4, "abcd", "")}, + {vv(4, "abcd"), []buffer.View{v("")}, vv(4, "abcd", "")}, + {vv(4, "abcd"), []buffer.View{v("e")}, vv(5, "abcd", "e")}, + {vv(4, "abcd"), []buffer.View{v("e"), v("fg")}, vv(7, "abcd", "e", "fg")}, + {vv(4, "abcd"), []buffer.View{v(""), v("fg")}, vv(6, "abcd", "", "fg")}, + } + for _, tc := range testCases { + tc.vv.AppendViews(tc.in) + if got, want := tc.vv, tc.want; !reflect.DeepEqual(got, want) { + t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want) + } + } +} + func TestMemSize(t *testing.T) { const perViewCap = 128 views := make([]buffer.View, 2, 32) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 07b4393a4..75d8e1f03 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -985,7 +985,11 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker { } icmp := header.ICMPv6(last.Payload()) - if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want { + if got, want := icmp.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: last.SourceAddress(), + Dst: last.DestinationAddress(), + }); got != want { t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) } diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 14a4b2b44..6aa9acfa8 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -186,42 +186,29 @@ func Checksum(buf []byte, initial uint16) uint16 { // // The initial checksum must have been computed on an even number of bytes. func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 { - return ChecksumVVWithOffset(vv, initial, 0, vv.Size()) + var c Checksumer + for _, v := range vv.Views() { + c.Add([]byte(v)) + } + return ChecksumCombine(initial, c.Checksum()) } -// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the -// bytes in the given VectorizedView. -// -// The initial checksum must have been computed on an even number of bytes. -func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 { - odd := false - sum := initial - for _, v := range vv.Views() { - if len(v) == 0 { - continue - } - - if off >= len(v) { - off -= len(v) - continue - } - v = v[off:] - - l := len(v) - if l > size { - l = size - } - v = v[:l] - - sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum)) - - size -= len(v) - if size == 0 { - break - } - off = 0 +// Checksumer calculates checksum defined in RFC 1071. +type Checksumer struct { + sum uint16 + odd bool +} + +// Add adds b to checksum. +func (c *Checksumer) Add(b []byte) { + if len(b) > 0 { + c.sum, c.odd = unrolledCalculateChecksum(b, c.odd, uint32(c.sum)) } - return sum +} + +// Checksum returns the latest checksum value. +func (c *Checksumer) Checksum() uint16 { + return c.sum } // ChecksumCombine combines the two uint16 to form their checksum. This is done diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index 5ab20ee86..d267dabd0 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -17,6 +17,7 @@ package header_test import ( + "bytes" "fmt" "math/rand" "sync" @@ -26,86 +27,72 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -func TestChecksumVVWithOffset(t *testing.T) { +func TestChecksumer(t *testing.T) { testCases := []struct { - name string - vv buffer.VectorisedView - off, size int - initial uint16 - want uint16 + name string + data [][]byte + want uint16 }{ { name: "empty", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - }), - off: 0, - size: 0, want: 0, }, { - name: "OneView", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - }), - off: 0, - size: 5, + name: "OneOddView", + data: [][]byte{ + []byte{1, 9, 0, 5, 4}, + }, want: 1294, }, { - name: "TwoViews", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 0, - size: 11, + name: "TwoOddViews", + data: [][]byte{ + []byte{1, 9, 0, 5, 4}, + []byte{4, 3, 7, 1, 2, 123}, + }, want: 33819, }, { - name: "TwoViewsWithOffset", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 1, - size: 11, - want: 33819, + name: "OneEvenView", + data: [][]byte{ + []byte{1, 9, 0, 5}, + }, + want: 270, }, { - name: "ThreeViewsWithOffset", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 7, - size: 11, - want: 33819, + name: "TwoEvenViews", + data: [][]byte{ + buffer.NewViewFromBytes([]byte{98, 1, 9, 0}), + buffer.NewViewFromBytes([]byte{9, 0, 5, 4}), + }, + want: 30981, }, { - name: "ThreeViewsWithInitial", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}), - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}), - }), - initial: 77, - off: 7, - size: 11, - want: 33896, + name: "ThreeViews", + data: [][]byte{ + []byte{77, 11, 33, 0, 55, 44}, + []byte{98, 1, 9, 0, 5, 4}, + []byte{4, 3, 7, 1, 2, 123, 99}, + }, + want: 34236, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want { - t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want) + var all bytes.Buffer + var c header.Checksumer + for _, b := range tc.data { + c.Add(b) + // Append to the buffer. We will check the checksum as a whole later. + if _, err := all.Write(b); err != nil { + t.Fatalf("all.Write(b) = _, %s; want _, nil", err) + } + } + if got, want := c.Checksum(), tc.want; got != want { + t.Errorf("c.Checksum() = %d, want %d", got, want) } - v := tc.vv.ToView() - v.TrimFront(tc.off) - v.CapLength(tc.size) - if got, want := header.Checksum(v, tc.initial), tc.want; got != want { - t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want) + if got, want := header.Checksum(all.Bytes(), 0 /* initial */), tc.want; got != want { + t.Errorf("Checksum(flatten tc.data) = %d, want %d", got, want) } }) } @@ -228,7 +215,7 @@ func TestICMPv4Checksum(t *testing.T) { h.SetChecksum(want) testICMPChecksum(t, h.Checksum, func() uint16 { - return header.ICMPv4Checksum(h, vv) + return header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0)) }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) } @@ -260,6 +247,12 @@ func TestICMPv6Checksum(t *testing.T) { h.SetChecksum(want) testICMPChecksum(t, h.Checksum, func() uint16 { - return header.ICMPv6Checksum(h, src, dst, vv) + return header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: h, + Src: src, + Dst: dst, + PayloadCsum: header.ChecksumVV(vv, 0), + PayloadLen: vv.Size(), + }) }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) } diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index f840a4322..91c1c3cd2 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -18,7 +18,6 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv4 represents an ICMPv4 header stored in a byte array. @@ -198,8 +197,8 @@ func (b ICMPv4) SetSequence(sequence uint16) { // ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header, // and payload. -func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { - xsum := ChecksumVV(vv, 0) +func ICMPv4Checksum(h ICMPv4, payloadCsum uint16) uint16 { + xsum := payloadCsum // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. xsum = Checksum(h[:2], xsum) diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index eca9750ab..668da623a 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -18,7 +18,6 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv6 represents an ICMPv6 header stored in a byte array. @@ -262,12 +261,22 @@ func (b ICMPv6) Payload() []byte { return b[ICMPv6PayloadOffset:] } +// ICMPv6ChecksumParams contains parameters to calculate ICMPv6 checksum. +type ICMPv6ChecksumParams struct { + Header ICMPv6 + Src tcpip.Address + Dst tcpip.Address + PayloadCsum uint16 + PayloadLen int +} + // ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, // IPv6 src/dst addresses and the payload. -func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { - xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) +func ICMPv6Checksum(params ICMPv6ChecksumParams) uint16 { + h := params.Header - xsum = ChecksumVV(vv, xsum) + xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src, params.Dst, uint16(len(h)+params.PayloadLen)) + xsum = ChecksumCombine(xsum, params.PayloadCsum) // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. xsum = Checksum(h[:2], xsum) diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index 2042f214a..ebb4b2c1d 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -41,7 +41,7 @@ func ARP(pkt *stack.PacketBuffer) bool { // // Returns true if the header was successfully parsed. func IPv4(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + hdr, ok := pkt.Data().PullUp(header.IPv4MinimumSize) if !ok { return false } @@ -62,27 +62,29 @@ func IPv4(pkt *stack.PacketBuffer) bool { ipHdr = header.IPv4(hdr) pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) + pkt.Data().CapLength(int(ipHdr.TotalLength()) - len(hdr)) return true } // IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network // header with the IPv6 header. func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + hdr, ok := pkt.Data().PullUp(header.IPv6MinimumSize) if !ok { return 0, 0, 0, false, false } ipHdr := header.IPv6(hdr) - // dataClone consists of: + // Create a VV to parse the packet. We don't plan to modify anything here. + // dataVV consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. // - Any other payload data. views := [8]buffer.View{} - dataClone := pkt.Data.Clone(views[:]) - dataClone.TrimFront(header.IPv6MinimumSize) - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + dataVV := buffer.NewVectorisedView(0, views[:0]) + dataVV.AppendViews(pkt.Data().Views()) + dataVV.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataVV) // Iterate over the IPv6 extensions to find their length. var nextHdr tcpip.TransportProtocolNumber @@ -98,7 +100,7 @@ traverseExtensions: // If we exhaust the extension list, the entire packet is the IPv6 header // and (possibly) extensions. if done { - extensionsSize = dataClone.Size() + extensionsSize = dataVV.Size() break } @@ -110,12 +112,12 @@ traverseExtensions: fragMore = extHdr.More() } rawPayload := it.AsRawHeader(true /* consume */) - extensionsSize = dataClone.Size() - rawPayload.Buf.Size() + extensionsSize = dataVV.Size() - rawPayload.Buf.Size() break traverseExtensions case header.IPv6RawPayloadHeader: // We've found the payload after any extensions. - extensionsSize = dataClone.Size() - extHdr.Buf.Size() + extensionsSize = dataVV.Size() - extHdr.Buf.Size() nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) break traverseExtensions @@ -127,10 +129,10 @@ traverseExtensions: // Put the IPv6 header with extensions in pkt.NetworkHeader(). hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) if !ok { - panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data().Size())) } ipHdr = header.IPv6(hdr) - pkt.Data.CapLength(int(ipHdr.PayloadLength())) + pkt.Data().CapLength(int(ipHdr.PayloadLength())) pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber return nextHdr, fragID, fragOffset, fragMore, true @@ -153,13 +155,13 @@ func UDP(pkt *stack.PacketBuffer) bool { func TCP(pkt *stack.PacketBuffer) bool { // TCP header is variable length, peek at it first. hdrLen := header.TCPMinimumSize - hdr, ok := pkt.Data.PullUp(hdrLen) + hdr, ok := pkt.Data().PullUp(hdrLen) if !ok { return false } // If the header has options, pull those up as well. - if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data().Size() { // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of // packets. hdrLen = offset diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 72d3f70ac..e17e2085c 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -427,7 +427,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen vnetHdr.csumOffset = gso.CsumOffset } - if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { + if gso.Type != stack.GSONone && uint16(pkt.Data().Size()) > gso.MSS { switch gso.Type { case stack.GSOTCPv4: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 @@ -468,7 +468,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset } - if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data.Size()) > pkt.GSOOptions.MSS { + if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS { switch pkt.GSOOptions.Type { case stack.GSOTCPv4: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 358a030d2..1e40f3fef 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -67,7 +67,7 @@ func checkPacketInfoEqual(t *testing.T, got, want packetInfo) { LinkHeader: pk.LinkHeader().View(), NetworkHeader: pk.NetworkHeader().View(), TransportHeader: pk.TransportHeader().View(), - Data: pk.Data.ToView(), + Data: pk.Data().AsRange().ToOwnedView(), } }), ); diff != "" { @@ -616,8 +616,8 @@ func TestDispatchPacketFormat(t *testing.T) { if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want { t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want) } - if got, want := pkt.Data.Size(), 4; got != want { - t.Errorf("pkt.Data.Size() = %d, want %d", got, want) + if got, want := pkt.Data().Size(), 4; got != want { + t.Errorf("pkt.Data().Size() = %d, want %d", got, want) } }) } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 736871d1c..46df87f44 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -165,7 +165,7 @@ func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { // We don't get any indication of what the packet is, so try to guess // if it's an IPv4 or IPv6 packet. // IP version information is at the first octet, so pulling up 1 byte. - h, ok := pkt.Data.PullUp(1) + h, ok := pkt.Data().PullUp(1) if !ok { return true, nil } @@ -270,7 +270,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) { // We don't get any indication of what the packet is, so try to guess // if it's an IPv4 or IPv6 packet. // IP version information is at the first octet, so pulling up 1 byte. - h, ok := pkt.Data.PullUp(1) + h, ok := pkt.Data().PullUp(1) if !ok { // Skip this packet. continue diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index def47772f..d4b3ddd5c 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -80,7 +80,7 @@ func (q *queueBuffers) cleanup() { type packetInfo struct { addr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - vv buffer.VectorisedView + data buffer.View linkHeader buffer.View } @@ -136,7 +136,7 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, proto: proto, - vv: pkt.Data.Clone(nil), + data: pkt.Data().AsRange().ToOwnedView(), }) c.mu.Unlock() @@ -676,7 +676,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.ToView()) + rcvd := []byte(c.packets[0].data) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index bd2b8d4bf..84189bba5 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -290,7 +290,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + hdr, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize) if !ok { break } @@ -327,7 +327,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) + hdr, ok := pkt.Data().PullUp(header.ICMPv6MinimumSize) if !ok { break } @@ -387,7 +387,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) break } - if size := pkt.Data.Size() + len(tcp); offset > size && !moreFragments { + if size := pkt.Data().Size() + len(tcp); offset > size && !moreFragments { details += fmt.Sprintf("invalid packet: tcp data offset %d larger than tcp packet length %d", offset, size) break } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 3829ca9c9..c1678c4f4 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -281,7 +281,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { vv.AppendView(info.Pkt.NetworkHeader().View()) vv.AppendView(info.Pkt.TransportHeader().View()) // Append data payload. - vv.Append(info.Pkt.Data) + vv.Append(info.Pkt.Data().ExtractVV()) return vv.ToView(), true } diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation.go b/pkg/tcpip/network/internal/fragmentation/fragmentation.go index 243738951..5168f5361 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation.go @@ -170,7 +170,7 @@ func (f *Fragmentation) Process( return nil, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) } - if l := pkt.Data.Size(); l != int(fragmentSize) { + if l := pkt.Data().Size(); l != int(fragmentSize) { return nil, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) } @@ -293,7 +293,7 @@ func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, re // these headers. var fragmentableData buffer.VectorisedView fragmentableData.AppendView(pkt.TransportHeader().View()) - fragmentableData.Append(pkt.Data) + fragmentableData.Append(pkt.Data().ExtractVV()) fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen return PacketFragmenter{ @@ -323,7 +323,7 @@ func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, }) // Copy data for the fragment. - copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen) + copied := fragPkt.Data().ReadFromVV(&pf.data, pf.fragmentPayloadLen) offset := pf.fragmentOffset pf.fragmentOffset += copied diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go index 47ea3173e..7daf64b4a 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go @@ -121,7 +121,7 @@ func TestFragmentationProcess(t *testing.T) { in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) } if c.out[i].done { - if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data.ToOwnedView()); diff != "" { + if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data().AsRange().ToOwnedView()); diff != "" { t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s", in.id, in.first, in.last, in.more, in.proto, in.pkt, diff) } @@ -470,9 +470,7 @@ func TestPacketFragmenter(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto) - var originalPayload buffer.VectorisedView - originalPayload.AppendView(pkt.TransportHeader().View()) - originalPayload.Append(pkt.Data) + originalPayload := stack.PayloadSince(pkt.TransportHeader()) var reassembledPayload buffer.VectorisedView pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) for i := 0; ; i++ { @@ -499,7 +497,7 @@ func TestPacketFragmenter(t *testing.T) { if got := fragPkt.TransportHeader().View().Size(); got != 0 { t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got) } - reassembledPayload.Append(fragPkt.Data) + reassembledPayload.AppendViews(fragPkt.Data().Views()) if !more { if i != len(test.wantFragments)-1 { t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1) @@ -507,7 +505,7 @@ func TestPacketFragmenter(t *testing.T) { break } } - if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" { + if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload); diff != "" { t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) } }) @@ -625,11 +623,11 @@ func TestTimeoutHandler(t *testing.T) { } switch { case handler.pkt != nil && test.wantPkt == nil: - t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView()) + t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data().AsRange().ToOwnedView()) case handler.pkt == nil && test.wantPkt != nil: - t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView()) + t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data().AsRange().ToOwnedView()) case handler.pkt != nil && test.wantPkt != nil: - if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" { + if diff := cmp.Diff(test.wantPkt.Data().AsRange().ToOwnedView(), handler.pkt.Data().AsRange().ToOwnedView()); diff != "" { t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff) } } diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 933d63d32..90075a70c 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -167,8 +167,8 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s resPkt := r.holes[0].pkt for i := 1; i < len(r.holes); i++ { - fragPkt := r.holes[i].pkt - fragPkt.Data.ReadToVV(&resPkt.Data, fragPkt.Data.Size()) + fragData := r.holes[i].pkt.Data() + resPkt.Data().ReadFromData(fragData, fragData.Size()) } return resPkt, r.proto, true, memConsumed, nil } diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go index 214a93709..cfd9f00ef 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go @@ -204,7 +204,7 @@ func TestReassemblerProcess(t *testing.T) { if a == nil || b == nil { return a == b } - return bytes.Equal(a.Data.ToOwnedView(), b.Data.ToOwnedView()) + return bytes.Equal(a.Data().AsRange().ToOwnedView(), b.Data().AsRange().ToOwnedView()) } if isDone { diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 90236ed9e..aee1652fa 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -90,8 +90,7 @@ type testObject struct { // checkValues verifies that the transport protocol, data contents, src & dst // addresses of a packet match what's expected. If any field doesn't match, the // test fails. -func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) { - v := vv.ToView() +func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, v buffer.View, srcAddr, dstAddr tcpip.Address) { if protocol != t.protocol { t.t.Errorf("protocol = %v, want %v", protocol, t.protocol) } @@ -120,7 +119,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // parsing are expected. func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { netHdr := pkt.Network() - t.checkValues(protocol, pkt.Data, netHdr.SourceAddress(), netHdr.DestinationAddress()) + t.checkValues(protocol, pkt.Data().AsRange().ToOwnedView(), netHdr.SourceAddress(), netHdr.DestinationAddress()) t.dataCalls++ return stack.TransportPacketHandled } @@ -129,7 +128,7 @@ func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumb // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) { - t.checkValues(trans, pkt.Data, remote, local) + t.checkValues(trans, pkt.Data().AsRange().ToOwnedView(), remote, local) if diff := cmp.Diff( t.transErr, transportError{ @@ -198,7 +197,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } - t.checkValues(prot, pkt.Data, srcAddr, dstAddr) + t.checkValues(prot, pkt.Data().AsRange().ToOwnedView(), srcAddr, dstAddr) return nil } @@ -371,7 +370,11 @@ func TestSourceAddressValidation(t *testing.T) { pkt.SetType(header.ICMPv6EchoRequest) pkt.SetCode(0) pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: src, + Dst: localIPv6Addr, + })) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: header.ICMPv6MinimumSize, @@ -1199,7 +1202,11 @@ func TestIPv6ReceiveControl(t *testing.T) { nic.testObject.transErr = c.transErr // Set ICMPv6 checksum. - icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: outerSrcAddr, + Dst: localIPv6Addr, + })) addressableEndpoint, ok := ep.(stack.AddressableEndpoint) if !ok { diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 5bf7809e8..deb104837 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -137,7 +137,7 @@ func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool { // is used to find out which transport endpoint must be notified about the ICMP // packet. We only expect the payload, not the enclosing ICMP packet. func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + h, ok := pkt.Data().PullUp(header.IPv4MinimumSize) if !ok { return } @@ -156,7 +156,7 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet } hlen := int(hdr.HeaderLength()) - if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { + if pkt.Data().Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -164,7 +164,7 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet } // Skip the ip header, then deliver the error. - pkt.Data.TrimFront(hlen) + pkt.Data().TrimFront(hlen) p := hdr.TransportProtocol() e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt) } @@ -174,7 +174,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. - v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize) if !ok { received.invalid.Increment() return @@ -182,7 +182,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { h := header.ICMPv4(v) // Only do in-stack processing if the checksum is correct. - if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff { + if pkt.Data().AsRange().Checksum() != 0xffff { received.invalid.Increment() // It's possible that a raw socket expects to receive this regardless // of checksum errors. If it's an echo request we know it's safe because @@ -253,7 +253,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no // waiting endpoints. Consider moving responsibility for doing the copy to // DeliverTransportPacket so that is is only done when needed. - replyData := pkt.Data.ToOwnedView() + replyData := pkt.Data().AsRange().ToOwnedView() ipHdr := header.IPv4(pkt.NetworkHeader().View()) localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast @@ -336,7 +336,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4DstUnreachable: received.dstUnreachable.Increment() - pkt.Data.TrimFront(header.ICMPv4MinimumSize) + pkt.Data().TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4HostUnreachable: e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) @@ -571,7 +571,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return nil } - payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data.Size() + payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data().Size() if payloadLen > available { payloadLen = available } @@ -586,8 +586,11 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip newHeader := append(buffer.View(nil), origIPHdr...) newHeader = append(newHeader, transportHeader...) payload := newHeader.ToVectorisedView() - payload.AppendView(pkt.Data.ToView()) - payload.CapLength(payloadLen) + if dataCap := payloadLen - payload.Size(); dataCap > 0 { + payload.AppendView(pkt.Data().AsRange().Capped(dataCap).ToOwnedView()) + } else { + payload.CapLength(payloadLen) + } icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize, @@ -623,7 +626,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum())) if err := route.WritePacket( nil, /* gso */ diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index 0a15ae897..f3fc1c87e 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -197,7 +197,7 @@ func (igmp *igmpState) isPacketValidLocked(pkt *stack.PacketBuffer, messageType // Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer, hasRouterAlertOption bool) { received := igmp.ep.stats.igmp.packetsReceived - headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize) + headerView, ok := pkt.Data().PullUp(header.IGMPMinimumSize) if !ok { received.invalid.Increment() return @@ -210,7 +210,7 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer, hasRouterAlertOption // same set of octets, including the checksum field. If the result // is all 1 bits (-0 in 1's complement arithmetic), the check // succeeds. - if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xFFFF { + if pkt.Data().AsRange().Checksum() != 0xFFFF { received.checksumErrors.Increment() return } diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index c5f68e411..e5e1b89cc 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -106,9 +106,9 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma igmp.SetGroupAddress(groupAddress) igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) - e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } // TestIGMPV1Present tests the node's ability to fallback to V1 when a V1 diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 4a429ea6c..cabe274d6 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -492,7 +492,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + h, ok := pkt.Data().PullUp(header.IPv4MinimumSize) if !ok { return &tcpip.ErrMalformedHeader{} } @@ -502,14 +502,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return &tcpip.ErrMalformedHeader{} } - h, ok = pkt.Data.PullUp(int(hdrLen)) + h, ok = pkt.Data().PullUp(int(hdrLen)) if !ok { return &tcpip.ErrMalformedHeader{} } ip := header.IPv4(h) // Always set the total length. - pktSize := pkt.Data.Size() + pktSize := pkt.Data().Size() ip.SetTotalLength(uint16(pktSize)) // Set the source address when zero. @@ -687,7 +687,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { stats := e.stats h := header.IPv4(pkt.NetworkHeader().View()) - if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { + if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { stats.ip.MalformedPacketsReceived.Increment() return } @@ -765,7 +765,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { } if h.More() || h.FragmentOffset() != 0 { - if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { + if pkt.Data().Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. stats.ip.MalformedPacketsReceived.Increment() @@ -793,10 +793,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // maximum payload size. // // Note that this addition doesn't overflow even on 32bit architecture - // because pkt.Data.Size() should not exceed 65535 (the max IP datagram + // because pkt.Data().Size() should not exceed 65535 (the max IP datagram // size). Otherwise the packet would've been rejected as invalid before // reaching here. - if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { + if int(start)+pkt.Data().Size() > header.IPv4MaximumPayloadSize { stats.ip.MalformedPacketsReceived.Increment() stats.ip.MalformedFragmentsReceived.Increment() return @@ -813,7 +813,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { Protocol: proto, }, start, - start+uint16(pkt.Data.Size())-1, + start+uint16(pkt.Data().Size())-1, h.More(), proto, pkt, @@ -831,7 +831,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // The reassembler doesn't take care of fixing up the header, so we need // to do it here. - h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) + h.SetTotalLength(uint16(pkt.Data().Size() + len((h)))) h.SetFlagsFragmentOffset(0, 0) } stats.ip.PacketsDelivered.Increment() @@ -1186,7 +1186,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error } func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { - payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + payload := pkt.TransportHeader().View().Size() + pkt.Data().Size() return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index dc4db6e5f..26d9696d7 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -1211,7 +1211,7 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset) } reassembledPayload.AppendView(packet.TransportHeader().View()) - reassembledPayload.Append(packet.Data) + reassembledPayload.AppendView(packet.Data().AsRange().ToOwnedView()) // Clear out the checksum and length from the ip because we can't compare // it. sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 5f44ab317..e80e681da 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -18,7 +18,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -165,7 +164,7 @@ func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool { // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + h, ok := pkt.Data().PullUp(header.IPv6MinimumSize) if !ok { return } @@ -184,10 +183,10 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // Skip the IP header, then handle the fragmentation header if there // is one. - pkt.Data.TrimFront(header.IPv6MinimumSize) + pkt.Data().TrimFront(header.IPv6MinimumSize) p := hdr.TransportProtocol() if p == header.IPv6FragmentHeader { - f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize) if !ok { return } @@ -200,7 +199,7 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // Skip fragmentation header and find out the actual protocol // number. - pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) + pkt.Data().TrimFront(header.IPv6FragmentHeaderSize) p = fragHdr.TransportProtocol() } @@ -268,7 +267,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP if routerAlert == nil || routerAlert.Value != header.IPv6RouterAlertMLD { return false } - if pkt.Data.Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize { + if pkt.Data().Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize { return false } if iph.HopLimit() != header.MLDHopLimit { @@ -285,7 +284,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r received := e.stats.icmp.packetsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader // fields set. See icmp/protocol.go:protocol.Parse for a full explanation. - v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize) if !ok { received.invalid.Increment() return @@ -296,11 +295,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r dstAddr := iph.DestinationAddress() // Validate ICMPv6 checksum before processing the packet. - // - // This copy is used as extra payload during the checksum calculation. - payload := pkt.Data.Clone(nil) - payload.TrimFront(len(h)) - if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want { + payload := pkt.Data().AsRange().SubRange(len(h)) + if got, want := h.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: h, + Src: srcAddr, + Dst: dstAddr, + PayloadCsum: payload.Checksum(), + PayloadLen: payload.Size(), + }); got != want { received.invalid.Increment() return } @@ -320,12 +322,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: received.packetTooBig.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + hdr, ok := pkt.Data().PullUp(header.ICMPv6PacketTooBigMinimumSize) if !ok { received.invalid.Increment() return } - pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) + pkt.Data().TrimFront(header.ICMPv6PacketTooBigMinimumSize) networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) if err != nil { networkMTU = 0 @@ -334,12 +336,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6DstUnreachable: received.dstUnreachable.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + hdr, ok := pkt.Data().PullUp(header.ICMPv6DstUnreachableMinimumSize) if !ok { received.invalid.Increment() return } - pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) + pkt.Data().TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch header.ICMPv6(hdr).Code() { case header.ICMPv6NetworkUnreachable: e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) @@ -348,16 +350,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r } case header.ICMPv6NeighborSolicit: received.neighborSolicit.Increment() - if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize { + if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborSolicitMinimumSize { received.invalid.Increment() return } // The remainder of payload must be only the neighbor solicitation, so - // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // payload.AsView() always returns the solicitation. Per RFC 6980 section 5, // NDP messages cannot be fragmented. Also note that in the common case NDP - // datagrams are very small and ToView() will not incur allocations. - ns := header.NDPNeighborSolicit(payload.ToView()) + // datagrams are very small and AsView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.AsView()) targetAddr := ns.TargetAddress() // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast @@ -529,7 +531,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r na.SetOverrideFlag(true) na.SetTargetAddress(targetAddr) na.Options().Serialize(optsSerializer) - packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: packet, + Src: r.LocalAddress, + Dst: r.RemoteAddress, + })) // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) // @@ -545,16 +551,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6NeighborAdvert: received.neighborAdvert.Increment() - if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize { + if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborAdvertMinimumSize { received.invalid.Increment() return } // The remainder of payload must be only the neighbor advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section + // payload.AsView() always returns the advertisement. Per RFC 6980 section // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - na := header.NDPNeighborAdvert(payload.ToView()) + // NDP datagrams are very small and AsView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.AsView()) targetAddr := na.TargetAddress() e.dad.mu.Lock() @@ -657,13 +663,20 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, - Data: pkt.Data, + Data: pkt.Data().ExtractVV(), }) - packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) + icmp := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber - copy(packet, icmpHdr) - packet.SetType(header.ICMPv6EchoReply) - packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) + copy(icmp, icmpHdr) + icmp.SetType(header.ICMPv6EchoReply) + dataRange := replyPkt.Data().AsRange() + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: r.LocalAddress, + Dst: r.RemoteAddress, + PayloadCsum: dataRange.Checksum(), + PayloadLen: dataRange.Size(), + })) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), @@ -676,7 +689,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6EchoReply: received.echoReply.Increment() - if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { + if pkt.Data().Size() < header.ICMPv6EchoMinimumSize { received.invalid.Increment() return } @@ -696,7 +709,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // Is the NDP payload of sufficient size to hold a Router Solictation? - if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { + if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { received.invalid.Increment() return } @@ -710,9 +723,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - // Note that in the common case NDP datagrams are very small and ToView() + // Note that in the common case NDP datagrams are very small and AsView() // will not incur allocations. - rs := header.NDPRouterSolicit(payload.ToView()) + rs := header.NDPRouterSolicit(payload.AsView()) it, err := rs.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. @@ -756,7 +769,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // Is the NDP payload of sufficient size to hold a Router Advertisement? - if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { + if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { received.invalid.Increment() return } @@ -770,9 +783,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - // Note that in the common case NDP datagrams are very small and ToView() + // Note that in the common case NDP datagrams are very small and AsView() // will not incur allocations. - ra := header.NDPRouterAdvert(payload.ToView()) + ra := header.NDPRouterAdvert(payload.AsView()) it, err := ra.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. @@ -850,11 +863,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r switch icmpType { case header.ICMPv6MulticastListenerQuery: e.mu.Lock() - e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView())) + e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.AsView())) e.mu.Unlock() case header.ICMPv6MulticastListenerReport: e.mu.Lock() - e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView())) + e.mu.mld.handleMulticastListenerReport(header.MLD(payload.AsView())) e.mu.Unlock() case header.ICMPv6MulticastListenerDone: default: @@ -1077,13 +1090,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip if available < header.IPv6MinimumSize { return nil } - payloadLen := network.Size() + transport.Size() + pkt.Data.Size() + payloadLen := network.Size() + transport.Size() + pkt.Data().Size() if payloadLen > available { payloadLen = available } payload := network.ToVectorisedView() payload.AppendView(transport) - payload.Append(pkt.Data) + payload.Append(pkt.Data().ExtractVV()) payload.CapLength(payloadLen) newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -1115,7 +1128,14 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data)) + dataRange := newPkt.Data().AsRange() + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: route.LocalAddress, + Dst: route.RemoteAddress, + PayloadCsum: dataRange.Checksum(), + PayloadLen: dataRange.Size(), + })) if err := route.WritePacket( nil, /* gso */ stack.NetworkHeaderParams{ diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index c27164344..d4e63710c 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -324,7 +324,13 @@ func TestICMPCounts(t *testing.T) { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp[:typ.size], + Src: lladdr0, + Dst: lladdr1, + PayloadCsum: header.Checksum(typ.extraData, 0 /* initial */), + PayloadLen: len(typ.extraData), + })) handleICMPInIPv6(ep, lladdr1, lladdr0, icmp, typ.hopLimit, typ.includeRouterAlert) } @@ -498,7 +504,11 @@ func TestLinkResolution(t *testing.T) { hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: r.LocalAddress, + Dst: r.RemoteAddress, + })) // We can't send our payload directly over the route because that // doesn't provoke NDP discovery. @@ -687,7 +697,11 @@ func TestICMPChecksumValidationSimple(t *testing.T) { copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) if checksum { - icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: lladdr1, + Dst: lladdr0, + })) } ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -879,7 +893,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { payloadFn(icmpHdr.Payload()) if checksum { - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{})) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: lladdr1, + Dst: lladdr0, + })) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -1058,7 +1076,13 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { payloadFn(payload) if checksum { - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView())) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: lladdr1, + Dst: lladdr0, + PayloadCsum: header.Checksum(payload, 0 /* initial */), + PayloadLen: len(payload), + })) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -1324,7 +1348,11 @@ func TestPacketQueing(t *testing.T) { pkt.SetType(header.ICMPv6EchoRequest) pkt.SetCode(0) pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: host2IPv6Addr.AddressWithPrefix.Address, + Dst: host1IPv6Addr.AddressWithPrefix.Address, + })) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: header.ICMPv6MinimumSize, @@ -1422,7 +1450,11 @@ func TestPacketQueing(t *testing.T) { na.Options().Serialize(header.NDPOptionsSerializer{ header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr), }) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: host2IPv6Addr.AddressWithPrefix.Address, + Dst: host1IPv6Addr.AddressWithPrefix.Address, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -1661,7 +1693,11 @@ func TestCallsToNeighborCache(t *testing.T) { } icmp := test.createPacket() - icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: test.source, + Dst: test.destination, + })) handleICMPInIPv6(ep, test.source, test.destination, icmp, header.NDPHopLimit, false) // Confirm the endpoint calls the correct NUDHandler method. diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 7638ade35..544717678 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -619,7 +619,7 @@ func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params } func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { - payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + payload := pkt.TransportHeader().View().Size() + pkt.Data().Size() return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } @@ -819,14 +819,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { // The packet already has an IP header, but there are a few required checks. - h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + h, ok := pkt.Data().PullUp(header.IPv6MinimumSize) if !ok { return &tcpip.ErrMalformedHeader{} } ip := header.IPv6(h) // Always set the payload length. - pktSize := pkt.Data.Size() + pktSize := pkt.Data().Size() ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) // Set the source address when zero. @@ -964,7 +964,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { stats := e.stats.ip h := header.IPv6(pkt.NetworkHeader().View()) - if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { + if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { stats.MalformedPacketsReceived.Increment() return } @@ -993,13 +993,14 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { return } + // Create a VV to parse the packet. We don't plan to modify anything here. // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. // - Any other payload data. vv := pkt.NetworkHeader().View()[header.IPv6MinimumSize:].ToVectorisedView() vv.AppendView(pkt.TransportHeader().View()) - vv.Append(pkt.Data) + vv.AppendViews(pkt.Data().Views()) it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) // iptables filtering. All packets that reach here are intended for @@ -1257,7 +1258,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // have more extension headers in the reassembled payload, as per RFC // 8200 section 4.5. We also use the NextHeader value from the first // fragment. - it = header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(proto), pkt.Data) + data := pkt.Data() + dataVV := buffer.NewVectorisedView(data.Size(), data.Views()) + it = header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(proto), dataVV) } case header.IPv6DestinationOptionsExtHdr: @@ -1314,7 +1317,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) - pkt.Data = extHdr.Buf + pkt.Data().Replace(extHdr.Buf) stats.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 05c9f4dbf..266a53e3b 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -68,7 +68,11 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize)) pkt.SetType(header.ICMPv6NeighborAdvert) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: src, + Dst: dst, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -216,7 +220,7 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB // Store the reassembled payload as we parse each fragment. The payload // includes the Transport header and everything after. reassembledPayload.AppendView(fragment.TransportHeader().View()) - reassembledPayload.Append(fragment.Data) + reassembledPayload.AppendView(fragment.Data().AsRange().ToOwnedView()) } if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" { @@ -3065,7 +3069,11 @@ func TestForwarding(t *testing.T) { icmp.SetType(header.ICMPv6EchoRequest) icmp.SetCode(header.ICMPv6UnusedCode) icmp.SetChecksum(0) - icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: remoteIPv6Addr1, + Dst: remoteIPv6Addr2, + })) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: header.ICMPv6MinimumSize, diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index 205e36cdd..dd153466d 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -236,7 +236,11 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp localAddress = header.IPv6Any } - icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: localAddress, + Dst: destAddress, + })) extensionHeaders := header.IPv6ExtHdrSerializer{ header.IPv6SerializableHopByHopExtHdr{ diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index f1b8d58f2..9a425e50a 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -326,11 +326,15 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType header.ICMPv6Type, ho mld := header.MLD(icmp.MessageBody()) mld.SetMaximumResponseDelay(0) mld.SetMulticastAddress(header.IPv6Any) - icmp.SetChecksum(header.ICMPv6Checksum(icmp, srcAddress, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: srcAddress, + Dst: header.IPv6AllNodesMulticastAddress, + })) - e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } func TestMLDPacketValidation(t *testing.T) { diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 721269c58..c22f60709 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1721,7 +1721,11 @@ func (ndp *ndpState) startSolicitingRouters() { icmpData.SetType(header.ICMPv6RouterSolicit) rs := header.NDPRouterSolicit(icmpData.MessageBody()) rs.Options().Serialize(optsSerializer) - icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpData, + Src: localAddr, + Dst: header.IPv6AllRoutersMulticastAddress, + })) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), @@ -1812,7 +1816,11 @@ func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteL ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(targetAddr) ns.Options().Serialize(opts) - icmp.SetChecksum(header.ICMPv6Checksum(icmp, srcAddr, dstAddr, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: srcAddr, + Dst: dstAddr, + })) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(e.MaxHeaderLength()), diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index ce20af0e3..0d53c260d 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -215,7 +215,11 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { ns.SetTargetAddress(lladdr0) opts := ns.Options() copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: lladdr1, + Dst: lladdr0, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -478,7 +482,11 @@ func TestNeighborSolicitationResponse(t *testing.T) { ns.SetTargetAddress(nicAddr) opts := ns.Options() opts.Serialize(test.nsOpts) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: test.nsSrc, + Dst: test.nsDst, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -554,7 +562,11 @@ func TestNeighborSolicitationResponse(t *testing.T) { na.SetOverrideFlag(true) na.SetTargetAddress(test.nsSrc) na.Options().Serialize(ser) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: test.nsSrc, + Dst: nicAddr, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -657,7 +669,11 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { ns.SetTargetAddress(lladdr1) opts := ns.Options() copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: lladdr1, + Dst: lladdr0, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -874,7 +890,13 @@ func TestNDPValidation(t *testing.T) { copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp[:typ.size], + Src: lladdr0, + Dst: lladdr1, + PayloadCsum: header.Checksum(typ.extraData /* initial */, 0), + PayloadLen: len(typ.extraData), + })) // Rx count of the NDP message should initially be 0. if got := typStat.Value(); got != 0 { @@ -987,7 +1009,11 @@ func TestNeighborAdvertisementValidation(t *testing.T) { na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetTargetAddress(lladdr1) na.SetSolicitedFlag(test.solicitedFlag) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, test.ipDstAddr, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: lladdr1, + Dst: test.ipDstAddr, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -1182,7 +1208,11 @@ func TestRouterAdvertValidation(t *testing.T) { pkt.SetCode(test.code) copy(pkt.MessageBody(), test.ndpPayload) payloadLength := hdr.UsedLength() - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: test.src, + Dst: header.IPv6AllNodesMulticastAddress, + })) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index 73913aef8..ecd5003a7 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -230,9 +230,9 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime b igmp.SetGroupAddress(groupAddress) igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) - e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } // createAndInjectMLDPacket creates and injects an MLD packet with the @@ -263,11 +263,15 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay b mld := header.MLD(icmp.MessageBody()) mld.SetMaximumResponseDelay(uint16(maxRespDelay)) mld.SetMulticastAddress(groupAddress) - icmp.SetChecksum(header.ICMPv6Checksum(icmp, linkLocalIPv6Addr2, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: linkLocalIPv6Addr2, + Dst: header.IPv6AllNodesMulticastAddress, + })) - e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } // TestMGPDisabled tests that the multicast group protocol is not enabled by diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index cdb435644..3f083928f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -407,12 +407,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data.Size()) + length := uint16(len(tcpHeader) + pkt.Data().Size()) xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumVV(pkt.Data, xsum) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index d63e9757c..0e8b90c9b 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -153,7 +153,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs if r.RequiresTXTransportChecksum() { length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumVV(pkt.Data, xsum) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 740bdac28..78a4cb072 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -573,7 +573,11 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(tgt) snmc := header.SolicitedNodeAddr(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: header.IPv6Any, + Dst: snmc, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -619,7 +623,11 @@ func TestDADFail(t *testing.T) { na.Options().Serialize(header.NDPOptionsSerializer{ header.NDPTargetLinkLayerAddressOption(linkAddr1), }) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: tgt, + Dst: header.IPv6AllNodesMulticastAddress, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -973,7 +981,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo } opts := ra.Options() opts.Serialize(optSer) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: ip, + Dst: header.IPv6AllNodesMulticastAddress, + })) payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) iph.Encode(&header.IPv6Fields{ diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f9323d545..62f7c880e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -725,12 +725,12 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.mu.RUnlock() n.stats.DisabledRx.Packets.Increment() - n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size())) return } n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size())) networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { @@ -881,7 +881,7 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - transHeader, ok := pkt.Data.PullUp(8) + transHeader, ok := pkt.Data().PullUp(8) if !ok { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 4f013b212..8f288675d 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -59,7 +59,7 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // Data holds the payload of the packet. + // data holds the payload of the packet. // // For inbound packets, Data is initially the whole packet. Then gets moved to // headers via PacketHeader.Consume, when the packet is being parsed. @@ -69,7 +69,7 @@ type PacketBuffer struct { // // The bytes backing Data are immutable, a.k.a. users shouldn't write to its // backing storage. - Data buffer.VectorisedView + data buffer.VectorisedView // headers stores metadata about each header. headers [numHeaderType]headerInfo @@ -127,7 +127,7 @@ type PacketBuffer struct { // NewPacketBuffer creates a new PacketBuffer with opts. func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { pk := &PacketBuffer{ - Data: opts.Data, + data: opts.Data, } if opts.ReserveHeaderBytes != 0 { pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) @@ -184,13 +184,18 @@ func (pk *PacketBuffer) HeaderSize() int { // Size returns the size of packet in bytes. func (pk *PacketBuffer) Size() int { - return pk.HeaderSize() + pk.Data.Size() + return pk.HeaderSize() + pk.data.Size() } // MemSize returns the estimation size of the pk in memory, including backing // buffer data. func (pk *PacketBuffer) MemSize() int { - return pk.HeaderSize() + pk.Data.MemSize() + packetBufferStructSize + return pk.HeaderSize() + pk.data.MemSize() + packetBufferStructSize +} + +// Data returns the handle to data portion of pk. +func (pk *PacketBuffer) Data() PacketData { + return PacketData{pk: pk} } // Views returns the underlying storage of the whole packet. @@ -204,7 +209,7 @@ func (pk *PacketBuffer) Views() []buffer.View { } } - dataViews := pk.Data.Views() + dataViews := pk.data.Views() var vs []buffer.View if useHeader { @@ -242,11 +247,11 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum if h.buf != nil { panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) } - v, ok := pk.Data.PullUp(size) + v, ok := pk.data.PullUp(size) if !ok { return } - pk.Data.TrimFront(size) + pk.data.TrimFront(size) h.buf = v return h.buf, true } @@ -258,7 +263,7 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum func (pk *PacketBuffer) Clone() *PacketBuffer { return &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), + data: pk.data.Clone(nil), headers: pk.headers, header: pk.header, Hash: pk.Hash, @@ -339,13 +344,234 @@ func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { return h.pk.consume(h.typ, size) } +// PacketData represents the data portion of a PacketBuffer. +type PacketData struct { + pk *PacketBuffer +} + +// PullUp returns a contiguous view of size bytes from the beginning of d. +// Callers should not write to or keep the view for later use. +func (d PacketData) PullUp(size int) (buffer.View, bool) { + return d.pk.data.PullUp(size) +} + +// TrimFront removes count from the beginning of d. It panics if count > +// d.Size(). +func (d PacketData) TrimFront(count int) { + d.pk.data.TrimFront(count) +} + +// CapLength reduces d to at most length bytes. +func (d PacketData) CapLength(length int) { + d.pk.data.CapLength(length) +} + +// Views returns the underlying storage of d in a slice of Views. Caller should +// not modify the returned slice. +func (d PacketData) Views() []buffer.View { + return d.pk.data.Views() +} + +// AppendView appends v into d, taking the ownership of v. +func (d PacketData) AppendView(v buffer.View) { + d.pk.data.AppendView(v) +} + +// ReadFromData moves at most count bytes from the beginning of srcData to the +// end of d and returns the number of bytes moved. +func (d PacketData) ReadFromData(srcData PacketData, count int) int { + return srcData.pk.data.ReadToVV(&d.pk.data, count) +} + +// ReadFromVV moves at most count bytes from the beginning of srcVV to the end +// of d and returns the number of bytes moved. +func (d PacketData) ReadFromVV(srcVV *buffer.VectorisedView, count int) int { + return srcVV.ReadToVV(&d.pk.data, count) +} + +// Size returns the number of bytes in the data payload of the packet. +func (d PacketData) Size() int { + return d.pk.data.Size() +} + +// AsRange returns a Range representing the current data payload of the packet. +func (d PacketData) AsRange() Range { + return Range{ + pk: d.pk, + offset: d.pk.HeaderSize(), + length: d.Size(), + } +} + +// ExtractVV returns a VectorisedView of d. This method has the semantic to +// destruct the underlying packet, hence the packet cannot be used again. +// +// This method exists for compatibility between PacketBuffer and VectorisedView. +// It may be removed later and should be used with care. +func (d PacketData) ExtractVV() buffer.VectorisedView { + return d.pk.data +} + +// Replace replaces the data portion of the packet with vv, taking the ownership +// of vv. +// +// This method exists for compatibility between PacketBuffer and VectorisedView. +// It may be removed later and should be used with care. +func (d PacketData) Replace(vv buffer.VectorisedView) { + d.pk.data = vv +} + +// Range represents a contiguous subportion of a PacketBuffer. +type Range struct { + pk *PacketBuffer + offset int + length int +} + +// Size returns the number of bytes in r. +func (r Range) Size() int { + return r.length +} + +// SubRange returns a new Range starting at off bytes of r. It returns an empty +// range if off is out-of-bounds. +func (r Range) SubRange(off int) Range { + if off > r.length { + return Range{pk: r.pk} + } + return Range{ + pk: r.pk, + offset: r.offset + off, + length: r.length - off, + } +} + +// Capped returns a new Range with the same starting point of r and length +// capped at max. +func (r Range) Capped(max int) Range { + if r.length <= max { + return r + } + return Range{ + pk: r.pk, + offset: r.offset, + length: max, + } +} + +// AsView returns the backing storage of r if possible. It will allocate a new +// View if r spans multiple pieces internally. Caller should not write to the +// returned View in any way. +func (r Range) AsView() buffer.View { + var allocated bool + var v buffer.View + r.iterate(func(b []byte) { + if v == nil { + // v has not been assigned, allowing first view to be returned. + v = b + } else { + // v has been assigned. This range spans more than a view, a new view + // needs to be allocated. + if !allocated { + allocated = true + all := make([]byte, 0, r.length) + all = append(all, v...) + v = all + } + v = append(v, b...) + } + }) + return v +} + +// ToOwnedView returns a owned copy of data in r. +func (r Range) ToOwnedView() buffer.View { + if r.length == 0 { + return nil + } + all := make([]byte, 0, r.length) + r.iterate(func(b []byte) { + all = append(all, b...) + }) + return all +} + +// Checksum calculates the RFC 1071 checksum for the underlying bytes of r. +func (r Range) Checksum() uint16 { + var c header.Checksumer + r.iterate(c.Add) + return c.Checksum() +} + +// iterate calls fn for each piece in r. fn is always called with a non-empty +// slice. +func (r Range) iterate(fn func([]byte)) { + w := window{ + offset: r.offset, + length: r.length, + } + // Header portion. + for i := range r.pk.headers { + if b := w.process(r.pk.headers[i].buf); len(b) > 0 { + fn(b) + } + if w.isDone() { + break + } + } + // Data portion. + if !w.isDone() { + for _, v := range r.pk.data.Views() { + if b := w.process(v); len(b) > 0 { + fn(b) + } + if w.isDone() { + break + } + } + } +} + +// window represents contiguous region of byte stream. User would call process() +// to input bytes, and obtain a subslice that is inside the window. +type window struct { + offset int + length int +} + +// isDone returns true if the window has passed and further process() calls will +// always return an empty slice. This can be used to end processing early. +func (w *window) isDone() bool { + return w.length == 0 +} + +// process feeds b in and returns a subslice that is inside the window. The +// returned slice will be a subslice of b, and it does not keep b after method +// returns. This method may return an empty slice if nothing in b is inside the +// window. +func (w *window) process(b []byte) (inWindow []byte) { + if w.offset >= len(b) { + w.offset -= len(b) + return nil + } + if w.offset > 0 { + b = b[w.offset:] + w.offset = 0 + } + if w.length < len(b) { + b = b[:w.length] + } + w.length -= len(b) + return b +} + // PayloadSince returns packet payload starting from and including a particular // header. // // The returned View is owned by the caller - its backing buffer is separate // from the packet header's underlying packet buffer. func PayloadSince(h PacketHeader) buffer.View { - size := h.pk.Data.Size() + size := h.pk.data.Size() for _, hinfo := range h.pk.headers[h.typ:] { size += len(hinfo.buf) } @@ -356,7 +582,7 @@ func PayloadSince(h PacketHeader) buffer.View { v = append(v, hinfo.buf...) } - for _, view := range h.pk.Data.Views() { + for _, view := range h.pk.data.Views() { v = append(v, view...) } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index c6fa8da5f..6728370c3 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -15,9 +15,11 @@ package stack import ( "bytes" + "fmt" "testing" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) func TestPacketHeaderPush(t *testing.T) { @@ -110,7 +112,7 @@ func TestPacketHeaderPush(t *testing.T) { if got, want := pk.Size(), allHdrSize+len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data) + checkData(t, pk, test.data) checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), concatViews(test.link, test.network, test.transport, test.data)) // Check the after values for each header. @@ -204,7 +206,7 @@ func TestPacketHeaderConsume(t *testing.T) { transport = test.data[test.link+test.network:][:test.transport] payload = test.data[allHdrSize:] ) - checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload) + checkData(t, pk, payload) checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) // Check the after values for each header. checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) @@ -340,6 +342,158 @@ func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { } } +func TestPacketBufferData(t *testing.T) { + for _, tc := range []struct { + name string + makePkt func(*testing.T) *PacketBuffer + data string + }{ + { + name: "inbound packet", + makePkt: func(*testing.T) *PacketBuffer { + pkt := NewPacketBuffer(PacketBufferOptions{ + Data: vv("aabbbbccccccDATA"), + }) + pkt.LinkHeader().Consume(2) + pkt.NetworkHeader().Consume(4) + pkt.TransportHeader().Consume(6) + return pkt + }, + data: "DATA", + }, + { + name: "outbound packet", + makePkt: func(*testing.T) *PacketBuffer { + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: 12, + Data: vv("DATA"), + }) + copy(pkt.TransportHeader().Push(6), []byte("cccccc")) + copy(pkt.NetworkHeader().Push(4), []byte("bbbb")) + copy(pkt.LinkHeader().Push(2), []byte("aa")) + return pkt + }, + data: "DATA", + }, + } { + t.Run(tc.name, func(t *testing.T) { + // PullUp + for _, n := range []int{1, len(tc.data)} { + t.Run(fmt.Sprintf("PullUp%d", n), func(t *testing.T) { + pkt := tc.makePkt(t) + v, ok := pkt.Data().PullUp(n) + wantV := []byte(tc.data)[:n] + if !ok || !bytes.Equal(v, wantV) { + t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want %q, true", n, v, ok, wantV) + } + }) + } + t.Run("PullUpOutOfBounds", func(t *testing.T) { + n := len(tc.data) + 1 + pkt := tc.makePkt(t) + v, ok := pkt.Data().PullUp(n) + if ok || v != nil { + t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want nil, false", n, v, ok) + } + }) + + // TrimFront + for _, n := range []int{1, len(tc.data)} { + t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) { + pkt := tc.makePkt(t) + pkt.Data().TrimFront(n) + + checkData(t, pkt, []byte(tc.data)[n:]) + }) + } + + // CapLength + for _, n := range []int{0, 1, len(tc.data)} { + t.Run(fmt.Sprintf("CapLength%d", n), func(t *testing.T) { + pkt := tc.makePkt(t) + pkt.Data().CapLength(n) + + want := []byte(tc.data) + if n < len(want) { + want = want[:n] + } + checkData(t, pkt, want) + }) + } + + // Views + t.Run("Views", func(t *testing.T) { + pkt := tc.makePkt(t) + checkData(t, pkt, []byte(tc.data)) + }) + + // AppendView + t.Run("AppendView", func(t *testing.T) { + s := "APPEND" + + pkt := tc.makePkt(t) + pkt.Data().AppendView(buffer.View(s)) + + checkData(t, pkt, []byte(tc.data+s)) + }) + + // ReadFromData/VV + for _, n := range []int{0, 1, 2, 7, 10, 14, 20} { + t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) { + s := "TO READ" + otherPkt := NewPacketBuffer(PacketBufferOptions{ + Data: vv(s, s), + }) + s += s + + pkt := tc.makePkt(t) + pkt.Data().ReadFromData(otherPkt.Data(), n) + + if n < len(s) { + s = s[:n] + } + checkData(t, pkt, []byte(tc.data+s)) + }) + t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) { + s := "TO READ" + srcVV := vv(s, s) + s += s + + pkt := tc.makePkt(t) + pkt.Data().ReadFromVV(&srcVV, n) + + if n < len(s) { + s = s[:n] + } + checkData(t, pkt, []byte(tc.data+s)) + }) + } + + // ExtractVV + t.Run("ExtractVV", func(t *testing.T) { + pkt := tc.makePkt(t) + extractedVV := pkt.Data().ExtractVV() + + got := extractedVV.ToOwnedView() + want := []byte(tc.data) + if !bytes.Equal(got, want) { + t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want) + } + }) + + // Replace + t.Run("Replace", func(t *testing.T) { + s := "REPLACED" + + pkt := tc.makePkt(t) + pkt.Data().Replace(vv(s)) + + checkData(t, pkt, []byte(s)) + }) + }) + } +} + func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { t.Helper() reserved := opts.ReserveHeaderBytes @@ -356,7 +510,7 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO if got, want := pk.Size(), len(data); got != want { t.Errorf("Initial pk.Size() = %d, want %d", got, want) } - checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data) + checkData(t, pk, data) checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) // Check the initial values for each header. checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) @@ -383,6 +537,70 @@ func checkViewEqual(t *testing.T, what string, got, want buffer.View) { } } +func checkData(t *testing.T, pkt *PacketBuffer, want []byte) { + t.Helper() + if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) { + t.Errorf("pkt.Data().Views() = %x, want %x", got, want) + } + if got := pkt.Data().Size(); got != len(want) { + t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want)) + } + + t.Run("AsRange", func(t *testing.T) { + // Full range + checkRange(t, pkt.Data().AsRange(), want) + + // SubRange + for _, off := range []int{0, 1, len(want), len(want) + 1} { + t.Run(fmt.Sprintf("SubRange%d", off), func(t *testing.T) { + // Empty when off is greater than the size of range. + var sub []byte + if off < len(want) { + sub = want[off:] + } + checkRange(t, pkt.Data().AsRange().SubRange(off), sub) + }) + } + + // Capped + for _, n := range []int{0, 1, len(want), len(want) + 1} { + t.Run(fmt.Sprintf("Capped%d", n), func(t *testing.T) { + sub := want + if n < len(sub) { + sub = sub[:n] + } + checkRange(t, pkt.Data().AsRange().Capped(n), sub) + }) + } + }) +} + +func checkRange(t *testing.T, r Range, data []byte) { + if got, want := r.Size(), len(data); got != want { + t.Errorf("r.Size() = %d, want %d", got, want) + } + if got := r.AsView(); !bytes.Equal(got, data) { + t.Errorf("r.AsView() = %x, want %x", got, data) + } + if got := r.ToOwnedView(); !bytes.Equal(got, data) { + t.Errorf("r.ToOwnedView() = %x, want %x", got, data) + } + if got, want := r.Checksum(), header.Checksum(data, 0 /* initial */); got != want { + t.Errorf("r.Checksum() = %x, want %x", got, want) + } +} + +func vv(pieces ...string) buffer.VectorisedView { + var views []buffer.View + var size int + for _, p := range pieces { + v := buffer.View([]byte(p)) + size += len(v) + views = append(views, v) + } + return buffer.NewVectorisedView(size, views) +} + func makeView(size int) buffer.View { b := byte(size) return bytes.Repeat([]byte{b}, size) diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 8e39e828c..f45cf5fdf 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -137,11 +137,11 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + nb, ok := pkt.Data().PullUp(fakeNetHeaderLen) if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) + pkt.Data().TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), @@ -4294,7 +4294,7 @@ func TestWritePacketToRemote(t *testing.T) { if pkt.Route.RemoteLinkAddress != linkAddr2 { t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2) } - if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { + if diff := cmp.Diff(pkt.Pkt.Data().AsRange().ToOwnedView(), buffer.View(test.payload)); diff != "" { t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) } }) diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 165f73f21..405c74c65 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -675,9 +675,7 @@ func TestWritePacketsLinkResolution(t *testing.T) { Length: length, }) xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length) - for _, v := range pkt.Data.Views() { - xsum = header.Checksum(v, xsum) - } + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) pkts.PushBack(pkt) diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index e4439ba79..29266a4fc 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -75,7 +75,11 @@ func TestPingMulticastBroadcast(t *testing.T) { pkt.SetType(header.ICMPv6EchoRequest) pkt.SetCode(0) pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, utils.RemoteIPv6Addr, dst, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: utils.RemoteIPv6Addr, + Dst: dst, + })) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: header.ICMPv6MinimumSize, diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index f5e1a6e45..09e9d027d 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -417,7 +417,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - pkt.Data = data.ToVectorisedView() + pkt.Data().AppendView(data) if ttl == 0 { ttl = r.DefaultTTL() @@ -445,9 +445,15 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro return &tcpip.ErrInvalidEndpointState{} } - dataVV := data.ToVectorisedView() - icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) - pkt.Data = dataVV + pkt.Data().AppendView(data) + dataRange := pkt.Data().AsRange() + icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpv6, + Src: r.LocalAddress, + Dst: r.RemoteAddress, + PayloadCsum: dataRange.Checksum(), + PayloadLen: dataRange.Size(), + })) if ttl == 0 { ttl = r.DefaultTTL() @@ -763,7 +769,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB // ICMP socket's data includes ICMP header. packet.data = pkt.TransportHeader().View().ToVectorisedView() - packet.data.Append(pkt.Data) + packet.data.Append(pkt.Data().ExtractVV()) e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 73bb66830..367757d3b 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -432,7 +432,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Cooked packets can simply be queued. switch pkt.PktType { case tcpip.PacketHost: - packet.data = pkt.Data + packet.data = pkt.Data().ExtractVV() case tcpip.PacketOutgoing: // Strip Link Header. var combinedVV buffer.VectorisedView @@ -442,7 +442,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, if v := pkt.TransportHeader().View(); !v.IsEmpty() { combinedVV.AppendView(v) } - combinedVV.Append(pkt.Data) + combinedVV.Append(pkt.Data().ExtractVV()) packet.data = combinedVV default: panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) @@ -468,7 +468,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...) } combinedVV := linkHeader.ToVectorisedView() - combinedVV.Append(pkt.Data) + combinedVV.Append(pkt.Data().ExtractVV()) packet.data = combinedVV } else { packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index fe8e9c751..2709be90c 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -644,7 +644,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } else { combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() } - combinedVV.Append(pkt.Data) + combinedVV.Append(pkt.Data().ExtractVV()) packet.data = combinedVV packet.timestampNS = e.stack.Clock().NowNanoseconds() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 461b1a9d7..d1e452421 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -752,7 +752,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumVV(pkt.Data, xsum) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } } @@ -786,7 +786,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso }) pkt.Hash = tf.txHash pkt.Owner = owner - data.ReadToVV(&pkt.Data, packetSize) + pkt.Data().ReadFromVV(&data, packetSize) buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) pkts.PushBack(pkt) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 9ce6868df..687b9f459 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2701,7 +2701,7 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p Cause: transErr, // Linux passes the payload with the TCP header. We don't know if the TCP // header even exists, it may not for fragmented packets. - Payload: pkt.Data.ToView(), + Payload: pkt.Data().AsRange().ToOwnedView(), Dst: tcpip.FullAddress{ NIC: pkt.NICID, Addr: e.ID.RemoteAddress, diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index f27eef6a9..744382100 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -98,7 +98,7 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) * netProto: pkt.NetworkProtocolNumber, nicID: pkt.NICID, } - s.data = pkt.Data.Clone(s.views[:]) + s.data = pkt.Data().ExtractVV().Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) s.rcvdTime = time.Now() s.dataMemSize = s.data.Size() diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 807df2bb5..b519afed1 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1227,7 +1227,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) { netHdr := pkt.Network() xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length()) - for _, v := range pkt.Data.Views() { + for _, v := range pkt.Data().Views() { xsum = header.Checksum(v, xsum) } return hdr.CalculateChecksum(xsum) == 0xffff @@ -1240,7 +1240,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. hdr := header.UDP(pkt.TransportHeader().View()) - if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { + if int(hdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1287,10 +1287,10 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB Addr: id.LocalAddress, Port: header.UDP(hdr).DestinationPort(), }, + data: pkt.Data().ExtractVV(), } - packet.data = pkt.Data e.rcvList.PushBack(packet) - e.rcvBufSize += pkt.Data.Size() + e.rcvBufSize += packet.data.Size() // Save any useful information from the network header to the packet. switch pkt.NetworkProtocolNumber { @@ -1327,7 +1327,7 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p if e.SocketOptions().GetRecvError() { // Linux passes the payload without the UDP header. var payload []byte - udp := header.UDP(pkt.Data.ToView()) + udp := header.UDP(pkt.Data().AsRange().ToOwnedView()) if len(udp) >= header.UDPMinimumSize { payload = udp.Payload() } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 427fdd0c9..1171aeb79 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -80,7 +80,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { // protocol but don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { hdr := header.UDP(pkt.TransportHeader().View()) - if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { + if int(hdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize { p.stack.Stats().UDP.MalformedPacketsReceived.Increment() return stack.UnknownDestinationPacketMalformed } diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index 19e6b8d7d..64a7a171a 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -852,7 +852,13 @@ func (l *ICMPv6) ToBytes() ([]byte, error) { if err != nil { return nil, err } - h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, payload)) + h.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: h, + Src: *ipv6.SrcAddr, + Dst: *ipv6.DstAddr, + PayloadCsum: header.ChecksumVV(payload, 0 /* initial */), + PayloadLen: payload.Size(), + })) break } } @@ -974,7 +980,7 @@ func (l *ICMPv4) ToBytes() ([]byte, error) { var vv buffer.VectorisedView vv.AppendView(buffer.View(l.Payload)) vv.Append(payload) - h.SetChecksum(header.ICMPv4Checksum(h, vv)) + h.SetChecksum(header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0 /* initial */))) } return h, nil diff --git a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go index ee050e2c6..707f0f1f5 100644 --- a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go +++ b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go @@ -21,7 +21,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/test/packetimpact/testbench" ) @@ -118,10 +117,7 @@ func TestIPv4FragmentReassembly(t *testing.T) { if _, err := rand.Read(originalPayload); err != nil { t.Fatalf("rand.Read: %s", err) } - cksum := header.ICMPv4Checksum( - icmp, - buffer.NewVectorisedView(len(originalPayload), []buffer.View{originalPayload}), - ) + cksum := header.ICMPv4Checksum(icmp, header.Checksum(originalPayload, 0 /* initial */)) icmp.SetChecksum(cksum) for _, fragment := range test.fragments { diff --git a/test/packetimpact/tests/ipv6_fragment_icmp_error_test.go b/test/packetimpact/tests/ipv6_fragment_icmp_error_test.go index e0b2a2178..4034a128e 100644 --- a/test/packetimpact/tests/ipv6_fragment_icmp_error_test.go +++ b/test/packetimpact/tests/ipv6_fragment_icmp_error_test.go @@ -21,7 +21,6 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/test/packetimpact/testbench" @@ -45,12 +44,13 @@ func fragmentedICMPEchoRequest(t *testing.T, n *testbench.DUTTestNet, conn *test icmpv6Header.SetCode(header.ICMPv6UnusedCode) icmpv6Header.SetIdent(0) icmpv6Header.SetSequence(0) - cksum := header.ICMPv6Checksum( - icmpv6Header, - tcpip.Address(n.LocalIPv6), - tcpip.Address(n.RemoteIPv6), - buffer.NewVectorisedView(len(payload), []buffer.View{payload}), - ) + cksum := header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpv6Header, + Src: tcpip.Address(n.LocalIPv6), + Dst: tcpip.Address(n.RemoteIPv6), + PayloadCsum: header.Checksum(payload, 0 /* initial */), + PayloadLen: len(payload), + }) icmpv6Header.SetChecksum(cksum) icmpv6Bytes := append([]byte(icmpv6Header), payload...) diff --git a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go index dd98ee7a1..db6195dc9 100644 --- a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go +++ b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go @@ -22,7 +22,6 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/test/packetimpact/testbench" ) @@ -121,12 +120,13 @@ func TestIPv6FragmentReassembly(t *testing.T) { t.Fatalf("rand.Read: %s", err) } - cksum := header.ICMPv6Checksum( - icmp, - lIP, - rIP, - buffer.NewVectorisedView(len(originalPayload), []buffer.View{originalPayload}), - ) + cksum := header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmp, + Src: lIP, + Dst: rIP, + PayloadCsum: header.Checksum(originalPayload, 0 /* initial */), + PayloadLen: len(originalPayload), + }) icmp.SetChecksum(cksum) for _, fragment := range test.fragments { |