From 72a30b11486b48394fa0edca500b80e4ca83b10c Mon Sep 17 00:00:00 2001 From: Arthur Sfez <asfez@google.com> Date: Tue, 15 Sep 2020 14:47:34 -0700 Subject: Move reusable IPv4 test code into a testutil module and refactor it The refactor aims to simplify the package, by replacing the Go channel with a PacketBuffer slice. This code will be reused by tests for IPv6 fragmentation. PiperOrigin-RevId: 331860411 --- pkg/tcpip/network/ipv4/BUILD | 1 + pkg/tcpip/network/ipv4/ipv4_test.go | 144 ++++++++------------------------- pkg/tcpip/network/testutil/BUILD | 17 ++++ pkg/tcpip/network/testutil/testutil.go | 92 +++++++++++++++++++++ 4 files changed, 143 insertions(+), 111 deletions(-) create mode 100644 pkg/tcpip/network/testutil/BUILD create mode 100644 pkg/tcpip/network/testutil/testutil.go (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index d142b4ffa..c82593e71 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -30,6 +30,7 @@ go_test( "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 2365b54f0..5e50558e8 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -17,8 +17,6 @@ package ipv4_test import ( "bytes" "encoding/hex" - "fmt" - "math/rand" "testing" "github.com/google/go-cmp/cmp" @@ -28,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -92,31 +91,6 @@ func TestExcludeBroadcast(t *testing.T) { }) } -// makeRandPkt generates a randomize packet. hdrLength indicates how much -// data should already be in the header before WritePacket. extraLength -// indicates how much extra space should be in the header. The payload is made -// from many Views of the sizes listed in viewSizes. -func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer { - var views []buffer.View - totalLength := 0 - for _, s := range viewSizes { - newView := buffer.NewView(s) - rand.Read(newView) - views = append(views, newView) - totalLength += s - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: hdrLength + extraLength, - Data: buffer.NewVectorisedView(totalLength, views), - }) - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil { - panic(fmt.Sprintf("rand.Read: %s", err)) - } - return pkt -} - // comparePayloads compared the contents of all the packets against the contents // of the source packet. func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { @@ -186,63 +160,19 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI } } -type errorChannel struct { - *channel.Endpoint - Ch chan *stack.PacketBuffer - packetCollectorErrors []*tcpip.Error -} - -// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket -// will return successive errors from packetCollectorErrors until the list is -// empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { - return &errorChannel{ - Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan *stack.PacketBuffer, size), - packetCollectorErrors: packetCollectorErrors, - } -} - -// Drain removes all outbound packets from the channel and counts them. -func (e *errorChannel) Drain() int { - c := 0 - for { - select { - case <-e.Ch: - c++ - default: - return c - } - } -} - -// WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - select { - case e.Ch <- pkt: - default: - } - - nextError := (*tcpip.Error)(nil) - if len(e.packetCollectorErrors) > 0 { - nextError = e.packetCollectorErrors[0] - e.packetCollectorErrors = e.packetCollectorErrors[1:] - } - return nextError -} - -type context struct { +type testRoute struct { stack.Route - linkEP *errorChannel + + linkEP *testutil.TestEndpoint } -func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { +func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute { // Make the packet and write it. s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, }) - ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - s.CreateNIC(1, ep) + testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors) + s.CreateNIC(1, testEP) const ( src = "\x10\x00\x00\x01" dst = "\x10\x00\x00\x02" @@ -262,9 +192,12 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32 if err != nil { t.Fatalf("s.FindRoute got %v, want %v", err, nil) } - return context{ + t.Cleanup(func() { + testEP.Close() + }) + return testRoute{ Route: r, - linkEP: ep, + linkEP: testEP, } } @@ -274,13 +207,13 @@ func TestFragmentation(t *testing.T) { manyPayloadViewsSizes[i] = 7 } fragTests := []struct { - description string - mtu uint32 - gso *stack.GSO - hdrLength int - extraLength int - payloadViewsSizes []int - expectedFrags int + description string + mtu uint32 + gso *stack.GSO + transportHeaderLength int + extraHeaderReserveLength int + payloadViewsSizes []int + expectedFrags int }{ {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, @@ -295,10 +228,10 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) + r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber) source := pkt.Clone() - c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ + err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, @@ -307,24 +240,13 @@ func TestFragmentation(t *testing.T) { t.Errorf("err got %v, want %v", err, nil) } - var results []*stack.PacketBuffer - L: - for { - select { - case pi := <-c.linkEP.Ch: - results = append(results, pi) - default: - break L - } - } - - if got, want := len(results), ft.expectedFrags; got != want { - t.Errorf("len(result) got %d, want %d", got, want) + if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want { + t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want) } - if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet len(result) got %d, want %d", got, want) + if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { + t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want) } - compareFragments(t, results, source, ft.mtu) + compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu) }) } } @@ -335,21 +257,21 @@ func TestFragmentationErrors(t *testing.T) { fragTests := []struct { description string mtu uint32 - hdrLength int + transportHeaderLength int payloadViewsSizes []int packetCollectorErrors []*tcpip.Error }{ {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, - {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, + {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, } for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) - c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber) + err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, @@ -364,7 +286,7 @@ func TestFragmentationErrors(t *testing.T) { if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { t.Errorf("err got %v, want %v", got, want) } - if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { + if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { t.Errorf("after linkEP error len(result) got %d, want %d", got, want) } }) diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD new file mode 100644 index 000000000..e218563d0 --- /dev/null +++ b/pkg/tcpip/network/testutil/BUILD @@ -0,0 +1,17 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "testutil", + srcs = [ + "testutil.go", + ], + visibility = ["//pkg/tcpip/network/ipv4:__pkg__"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go new file mode 100644 index 000000000..bf5ce74be --- /dev/null +++ b/pkg/tcpip/network/testutil/testutil.go @@ -0,0 +1,92 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testutil defines types and functions used to test Network Layer +// functionality such as IP fragmentation. +package testutil + +import ( + "fmt" + "math/rand" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// TestEndpoint is an endpoint used for testing, it stores packets written to it +// and can mock errors. +type TestEndpoint struct { + *channel.Endpoint + + // WrittenPackets is where we store packets written via WritePacket(). + WrittenPackets []*stack.PacketBuffer + + packetCollectorErrors []*tcpip.Error +} + +// NewTestEndpoint creates a new TestEndpoint endpoint. +// +// packetCollectorErrors can be used to set error values and each call to +// WritePacket will remove the first one from the slice and return it until +// the slice is empty - at that point it will return nil every time. +func NewTestEndpoint(ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) *TestEndpoint { + return &TestEndpoint{ + Endpoint: ep, + WrittenPackets: make([]*stack.PacketBuffer, 0), + packetCollectorErrors: packetCollectorErrors, + } +} + +// WritePacket stores outbound packets and may return an error if one was +// injected. +func (e *TestEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.WrittenPackets = append(e.WrittenPackets, pkt) + + if len(e.packetCollectorErrors) > 0 { + nextError := e.packetCollectorErrors[0] + e.packetCollectorErrors = e.packetCollectorErrors[1:] + return nextError + } + + return nil +} + +// MakeRandPkt generates a randomized packet. transportHeaderLength indicates +// how many random bytes will be copied in the Transport Header. +// extraHeaderReserveLength indicates how much extra space will be reserved for +// the other headers. The payload is made from Views of the sizes listed in +// viewSizes. +func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer { + var views buffer.VectorisedView + + for _, s := range viewSizes { + newView := buffer.NewView(s) + if _, err := rand.Read(newView); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + views.AppendView(newView) + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength, + Data: views, + }) + pkt.NetworkProtocolNumber = proto + if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + return pkt +} -- cgit v1.2.3