// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package stack import ( "bytes" "fmt" "testing" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) func TestPacketHeaderPush(t *testing.T) { for _, test := range []struct { name string reserved int link []byte network []byte transport []byte data []byte }{ { name: "construct empty packet", }, { name: "construct link header only packet", reserved: 60, link: makeView(10), }, { name: "construct link and network header only packet", reserved: 60, link: makeView(10), network: makeView(20), }, { name: "construct header only packet", reserved: 60, link: makeView(10), network: makeView(20), transport: makeView(30), }, { name: "construct data only packet", data: makeView(40), }, { name: "construct L3 packet", reserved: 60, network: makeView(20), transport: makeView(30), data: makeView(40), }, { name: "construct L2 packet", reserved: 60, link: makeView(10), network: makeView(20), transport: makeView(30), data: makeView(40), }, } { t.Run(test.name, func(t *testing.T) { pk := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: test.reserved, // Make a copy of data to make sure our truth data won't be taint by // PacketBuffer. Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), }) allHdrSize := len(test.link) + len(test.network) + len(test.transport) // Check the initial values for packet. checkInitialPacketBuffer(t, pk, PacketBufferOptions{ ReserveHeaderBytes: test.reserved, Data: buffer.View(test.data).ToVectorisedView(), }) // Push headers. if v := test.transport; len(v) > 0 { copy(pk.TransportHeader().Push(len(v)), v) } if v := test.network; len(v) > 0 { copy(pk.NetworkHeader().Push(len(v)), v) } if v := test.link; len(v) > 0 { copy(pk.LinkHeader().Push(len(v)), v) } // Check the after values for packet. if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want { t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) } if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want { t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) } if got, want := pk.HeaderSize(), allHdrSize; got != want { t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) } if got, want := pk.Size(), allHdrSize+len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } checkData(t, pk, test.data) checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), concatViews(test.link, test.network, test.transport, test.data)) // Check the after values for each header. checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) // Check the after values for PayloadSince. checkViewEqual(t, "After PayloadSince(LinkHeader)", PayloadSince(pk.LinkHeader()), concatViews(test.link, test.network, test.transport, test.data)) checkViewEqual(t, "After PayloadSince(NetworkHeader)", PayloadSince(pk.NetworkHeader()), concatViews(test.network, test.transport, test.data)) checkViewEqual(t, "After PayloadSince(TransportHeader)", PayloadSince(pk.TransportHeader()), concatViews(test.transport, test.data)) }) } } func TestPacketHeaderConsume(t *testing.T) { for _, test := range []struct { name string data []byte link int network int transport int }{ { name: "parse L2 packet", data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)), link: 10, network: 20, transport: 30, }, { name: "parse L3 packet", data: concatViews(makeView(20), makeView(30), makeView(40)), network: 20, transport: 30, }, } { t.Run(test.name, func(t *testing.T) { pk := NewPacketBuffer(PacketBufferOptions{ // Make a copy of data to make sure our truth data won't be taint by // PacketBuffer. Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), }) // Check the initial values for packet. checkInitialPacketBuffer(t, pk, PacketBufferOptions{ Data: buffer.View(test.data).ToVectorisedView(), }) // Consume headers. if size := test.link; size > 0 { if _, ok := pk.LinkHeader().Consume(size); !ok { t.Fatalf("pk.LinkHeader().Consume() = false, want true") } } if size := test.network; size > 0 { if _, ok := pk.NetworkHeader().Consume(size); !ok { t.Fatalf("pk.NetworkHeader().Consume() = false, want true") } } if size := test.transport; size > 0 { if _, ok := pk.TransportHeader().Consume(size); !ok { t.Fatalf("pk.TransportHeader().Consume() = false, want true") } } allHdrSize := test.link + test.network + test.transport // Check the after values for packet. if got, want := pk.ReservedHeaderBytes(), 0; got != want { t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) } if got, want := pk.AvailableHeaderBytes(), 0; got != want { t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) } if got, want := pk.HeaderSize(), allHdrSize; got != want { t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) } if got, want := pk.Size(), len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } // After state of pk. var ( link = test.data[:test.link] network = test.data[test.link:][:test.network] transport = test.data[test.link+test.network:][:test.transport] payload = test.data[allHdrSize:] ) checkData(t, pk, payload) checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) // Check the after values for each header. checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) // Check the after values for PayloadSince. checkViewEqual(t, "After PayloadSince(LinkHeader)", PayloadSince(pk.LinkHeader()), concatViews(link, network, transport, payload)) checkViewEqual(t, "After PayloadSince(NetworkHeader)", PayloadSince(pk.NetworkHeader()), concatViews(network, transport, payload)) checkViewEqual(t, "After PayloadSince(TransportHeader)", PayloadSince(pk.TransportHeader()), concatViews(transport, payload)) }) } } func TestPacketHeaderConsumeDataTooShort(t *testing.T) { data := makeView(10) pk := NewPacketBuffer(PacketBufferOptions{ // Make a copy of data to make sure our truth data won't be taint by // PacketBuffer. Data: buffer.NewViewFromBytes(data).ToVectorisedView(), }) // Consume should fail if pkt.Data is too short. if _, ok := pk.LinkHeader().Consume(11); ok { t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false") } if _, ok := pk.NetworkHeader().Consume(11); ok { t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false") } if _, ok := pk.TransportHeader().Consume(11); ok { t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false") } // Check packet should look the same as initial packet. checkInitialPacketBuffer(t, pk, PacketBufferOptions{ Data: buffer.View(data).ToVectorisedView(), }) } func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { const headerSize = 10 pk := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: headerSize * int(numHeaderType), }) for _, h := range []PacketHeader{ pk.TransportHeader(), pk.NetworkHeader(), pk.LinkHeader(), } { t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) { h.Push(headerSize) defer func() { recover() }() h.Push(headerSize) t.Fatal("Second push should have panicked") }) } } func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) { const headerSize = 10 pk := NewPacketBuffer(PacketBufferOptions{ Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), }) for _, h := range []PacketHeader{ pk.LinkHeader(), pk.NetworkHeader(), pk.TransportHeader(), } { t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) { if _, ok := h.Consume(headerSize); !ok { t.Fatal("First consume should succeed") } defer func() { recover() }() h.Consume(headerSize) t.Fatal("Second consume should have panicked") }) } } func TestPacketHeaderPushThenConsumePanics(t *testing.T) { const headerSize = 10 pk := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: headerSize * int(numHeaderType), }) for _, h := range []PacketHeader{ pk.TransportHeader(), pk.NetworkHeader(), pk.LinkHeader(), } { t.Run(h.typ.String(), func(t *testing.T) { h.Push(headerSize) defer func() { recover() }() h.Consume(headerSize) t.Fatal("Consume should have panicked") }) } } func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { const headerSize = 10 pk := NewPacketBuffer(PacketBufferOptions{ Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), }) for _, h := range []PacketHeader{ pk.LinkHeader(), pk.NetworkHeader(), pk.TransportHeader(), } { t.Run(h.typ.String(), func(t *testing.T) { h.Consume(headerSize) defer func() { recover() }() h.Push(headerSize) t.Fatal("Push should have panicked") }) } } func TestPacketBufferData(t *testing.T) { for _, tc := range []struct { name string makePkt func(*testing.T) *PacketBuffer data string }{ { name: "inbound packet", makePkt: func(*testing.T) *PacketBuffer { pkt := NewPacketBuffer(PacketBufferOptions{ Data: vv("aabbbbccccccDATA"), }) pkt.LinkHeader().Consume(2) pkt.NetworkHeader().Consume(4) pkt.TransportHeader().Consume(6) return pkt }, data: "DATA", }, { name: "outbound packet", makePkt: func(*testing.T) *PacketBuffer { pkt := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: 12, Data: vv("DATA"), }) copy(pkt.TransportHeader().Push(6), []byte("cccccc")) copy(pkt.NetworkHeader().Push(4), []byte("bbbb")) copy(pkt.LinkHeader().Push(2), []byte("aa")) return pkt }, data: "DATA", }, } { t.Run(tc.name, func(t *testing.T) { // PullUp for _, n := range []int{1, len(tc.data)} { t.Run(fmt.Sprintf("PullUp%d", n), func(t *testing.T) { pkt := tc.makePkt(t) v, ok := pkt.Data().PullUp(n) wantV := []byte(tc.data)[:n] if !ok || !bytes.Equal(v, wantV) { t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want %q, true", n, v, ok, wantV) } }) } t.Run("PullUpOutOfBounds", func(t *testing.T) { n := len(tc.data) + 1 pkt := tc.makePkt(t) v, ok := pkt.Data().PullUp(n) if ok || v != nil { t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want nil, false", n, v, ok) } }) // TrimFront for _, n := range []int{1, len(tc.data)} { t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) { pkt := tc.makePkt(t) pkt.Data().TrimFront(n) checkData(t, pkt, []byte(tc.data)[n:]) }) } // CapLength for _, n := range []int{0, 1, len(tc.data)} { t.Run(fmt.Sprintf("CapLength%d", n), func(t *testing.T) { pkt := tc.makePkt(t) pkt.Data().CapLength(n) want := []byte(tc.data) if n < len(want) { want = want[:n] } checkData(t, pkt, want) }) } // Views t.Run("Views", func(t *testing.T) { pkt := tc.makePkt(t) checkData(t, pkt, []byte(tc.data)) }) // AppendView t.Run("AppendView", func(t *testing.T) { s := "APPEND" pkt := tc.makePkt(t) pkt.Data().AppendView(buffer.View(s)) checkData(t, pkt, []byte(tc.data+s)) }) // ReadFromData/VV for _, n := range []int{0, 1, 2, 7, 10, 14, 20} { t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) { s := "TO READ" otherPkt := NewPacketBuffer(PacketBufferOptions{ Data: vv(s, s), }) s += s pkt := tc.makePkt(t) pkt.Data().ReadFromData(otherPkt.Data(), n) if n < len(s) { s = s[:n] } checkData(t, pkt, []byte(tc.data+s)) }) t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) { s := "TO READ" srcVV := vv(s, s) s += s pkt := tc.makePkt(t) pkt.Data().ReadFromVV(&srcVV, n) if n < len(s) { s = s[:n] } checkData(t, pkt, []byte(tc.data+s)) }) } // ExtractVV t.Run("ExtractVV", func(t *testing.T) { pkt := tc.makePkt(t) extractedVV := pkt.Data().ExtractVV() got := extractedVV.ToOwnedView() want := []byte(tc.data) if !bytes.Equal(got, want) { t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want) } }) // Replace t.Run("Replace", func(t *testing.T) { s := "REPLACED" pkt := tc.makePkt(t) pkt.Data().Replace(vv(s)) checkData(t, pkt, []byte(s)) }) }) } } func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { t.Helper() reserved := opts.ReserveHeaderBytes if got, want := pk.ReservedHeaderBytes(), reserved; got != want { t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want) } if got, want := pk.AvailableHeaderBytes(), reserved; got != want { t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want) } if got, want := pk.HeaderSize(), 0; got != want { t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want) } data := opts.Data.ToView() if got, want := pk.Size(), len(data); got != want { t.Errorf("Initial pk.Size() = %d, want %d", got, want) } checkData(t, pk, data) checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) // Check the initial values for each header. checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) // Check the initial valies for PayloadSince. checkViewEqual(t, "Initial PayloadSince(LinkHeader)", PayloadSince(pk.LinkHeader()), data) checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", PayloadSince(pk.NetworkHeader()), data) checkViewEqual(t, "Initial PayloadSince(TransportHeader)", PayloadSince(pk.TransportHeader()), data) } func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { t.Helper() checkViewEqual(t, name+".View()", h.View(), want) } func checkViewEqual(t *testing.T, what string, got, want buffer.View) { t.Helper() if !bytes.Equal(got, want) { t.Errorf("%s = %x, want %x", what, got, want) } } func checkData(t *testing.T, pkt *PacketBuffer, want []byte) { t.Helper() if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) { t.Errorf("pkt.Data().Views() = %x, want %x", got, want) } if got := pkt.Data().Size(); got != len(want) { t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want)) } t.Run("AsRange", func(t *testing.T) { // Full range checkRange(t, pkt.Data().AsRange(), want) // SubRange for _, off := range []int{0, 1, len(want), len(want) + 1} { t.Run(fmt.Sprintf("SubRange%d", off), func(t *testing.T) { // Empty when off is greater than the size of range. var sub []byte if off < len(want) { sub = want[off:] } checkRange(t, pkt.Data().AsRange().SubRange(off), sub) }) } // Capped for _, n := range []int{0, 1, len(want), len(want) + 1} { t.Run(fmt.Sprintf("Capped%d", n), func(t *testing.T) { sub := want if n < len(sub) { sub = sub[:n] } checkRange(t, pkt.Data().AsRange().Capped(n), sub) }) } }) } func checkRange(t *testing.T, r Range, data []byte) { if got, want := r.Size(), len(data); got != want { t.Errorf("r.Size() = %d, want %d", got, want) } if got := r.AsView(); !bytes.Equal(got, data) { t.Errorf("r.AsView() = %x, want %x", got, data) } if got := r.ToOwnedView(); !bytes.Equal(got, data) { t.Errorf("r.ToOwnedView() = %x, want %x", got, data) } if got, want := r.Checksum(), header.Checksum(data, 0 /* initial */); got != want { t.Errorf("r.Checksum() = %x, want %x", got, want) } } func vv(pieces ...string) buffer.VectorisedView { var views []buffer.View var size int for _, p := range pieces { v := buffer.View([]byte(p)) size += len(v) views = append(views, v) } return buffer.NewVectorisedView(size, views) } func makeView(size int) buffer.View { b := byte(size) return bytes.Repeat([]byte{b}, size) } func concatViews(views ...buffer.View) buffer.View { var all buffer.View for _, v := range views { all = append(all, v...) } return all }