summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
authorArthur Sfez <asfez@google.com>2020-09-15 14:47:34 -0700
committergVisor bot <gvisor-bot@google.com>2020-09-15 14:49:29 -0700
commit72a30b11486b48394fa0edca500b80e4ca83b10c (patch)
treeaa98e6266cc85773ddf7c1653c16606b87a27149 /pkg/tcpip
parent7f89a26e18edfe4335eaab965fb7a03eaebb2682 (diff)
Move reusable IPv4 test code into a testutil module and refactor it
The refactor aims to simplify the package, by replacing the Go channel with a PacketBuffer slice. This code will be reused by tests for IPv6 fragmentation. PiperOrigin-RevId: 331860411
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go144
-rw-r--r--pkg/tcpip/network/testutil/BUILD17
-rw-r--r--pkg/tcpip/network/testutil/testutil.go92
4 files changed, 143 insertions, 111 deletions
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index d142b4ffa..c82593e71 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -30,6 +30,7 @@ go_test(
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 2365b54f0..5e50558e8 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -17,8 +17,6 @@ package ipv4_test
import (
"bytes"
"encoding/hex"
- "fmt"
- "math/rand"
"testing"
"github.com/google/go-cmp/cmp"
@@ -28,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -92,31 +91,6 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-// makeRandPkt generates a randomize packet. hdrLength indicates how much
-// data should already be in the header before WritePacket. extraLength
-// indicates how much extra space should be in the header. The payload is made
-// from many Views of the sizes listed in viewSizes.
-func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer {
- var views []buffer.View
- totalLength := 0
- for _, s := range viewSizes {
- newView := buffer.NewView(s)
- rand.Read(newView)
- views = append(views, newView)
- totalLength += s
- }
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: hdrLength + extraLength,
- Data: buffer.NewVectorisedView(totalLength, views),
- })
- pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil {
- panic(fmt.Sprintf("rand.Read: %s", err))
- }
- return pkt
-}
-
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
@@ -186,63 +160,19 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
}
-type errorChannel struct {
- *channel.Endpoint
- Ch chan *stack.PacketBuffer
- packetCollectorErrors []*tcpip.Error
-}
-
-// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
-// will return successive errors from packetCollectorErrors until the list is
-// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
- return &errorChannel{
- Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan *stack.PacketBuffer, size),
- packetCollectorErrors: packetCollectorErrors,
- }
-}
-
-// Drain removes all outbound packets from the channel and counts them.
-func (e *errorChannel) Drain() int {
- c := 0
- for {
- select {
- case <-e.Ch:
- c++
- default:
- return c
- }
- }
-}
-
-// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- select {
- case e.Ch <- pkt:
- default:
- }
-
- nextError := (*tcpip.Error)(nil)
- if len(e.packetCollectorErrors) > 0 {
- nextError = e.packetCollectorErrors[0]
- e.packetCollectorErrors = e.packetCollectorErrors[1:]
- }
- return nextError
-}
-
-type context struct {
+type testRoute struct {
stack.Route
- linkEP *errorChannel
+
+ linkEP *testutil.TestEndpoint
}
-func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
+func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute {
// Make the packet and write it.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
})
- ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
- s.CreateNIC(1, ep)
+ testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors)
+ s.CreateNIC(1, testEP)
const (
src = "\x10\x00\x00\x01"
dst = "\x10\x00\x00\x02"
@@ -262,9 +192,12 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32
if err != nil {
t.Fatalf("s.FindRoute got %v, want %v", err, nil)
}
- return context{
+ t.Cleanup(func() {
+ testEP.Close()
+ })
+ return testRoute{
Route: r,
- linkEP: ep,
+ linkEP: testEP,
}
}
@@ -274,13 +207,13 @@ func TestFragmentation(t *testing.T) {
manyPayloadViewsSizes[i] = 7
}
fragTests := []struct {
- description string
- mtu uint32
- gso *stack.GSO
- hdrLength int
- extraLength int
- payloadViewsSizes []int
- expectedFrags int
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transportHeaderLength int
+ extraHeaderReserveLength int
+ payloadViewsSizes []int
+ expectedFrags int
}{
{"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
{"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
@@ -295,10 +228,10 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
+ r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
source := pkt.Clone()
- c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
@@ -307,24 +240,13 @@ func TestFragmentation(t *testing.T) {
t.Errorf("err got %v, want %v", err, nil)
}
- var results []*stack.PacketBuffer
- L:
- for {
- select {
- case pi := <-c.linkEP.Ch:
- results = append(results, pi)
- default:
- break L
- }
- }
-
- if got, want := len(results), ft.expectedFrags; got != want {
- t.Errorf("len(result) got %d, want %d", got, want)
+ if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want {
+ t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
}
- if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
- t.Errorf("no errors yet len(result) got %d, want %d", got, want)
+ if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
+ t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
}
- compareFragments(t, results, source, ft.mtu)
+ compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu)
})
}
}
@@ -335,21 +257,21 @@ func TestFragmentationErrors(t *testing.T) {
fragTests := []struct {
description string
mtu uint32
- hdrLength int
+ transportHeaderLength int
payloadViewsSizes []int
packetCollectorErrors []*tcpip.Error
}{
{"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
{"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
{"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
- {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
}
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
- c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
@@ -364,7 +286,7 @@ func TestFragmentationErrors(t *testing.T) {
if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
t.Errorf("err got %v, want %v", got, want)
}
- if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
+ if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
}
})
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
new file mode 100644
index 000000000..e218563d0
--- /dev/null
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "testutil",
+ srcs = [
+ "testutil.go",
+ ],
+ visibility = ["//pkg/tcpip/network/ipv4:__pkg__"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
new file mode 100644
index 000000000..bf5ce74be
--- /dev/null
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -0,0 +1,92 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil defines types and functions used to test Network Layer
+// functionality such as IP fragmentation.
+package testutil
+
+import (
+ "fmt"
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// TestEndpoint is an endpoint used for testing, it stores packets written to it
+// and can mock errors.
+type TestEndpoint struct {
+ *channel.Endpoint
+
+ // WrittenPackets is where we store packets written via WritePacket().
+ WrittenPackets []*stack.PacketBuffer
+
+ packetCollectorErrors []*tcpip.Error
+}
+
+// NewTestEndpoint creates a new TestEndpoint endpoint.
+//
+// packetCollectorErrors can be used to set error values and each call to
+// WritePacket will remove the first one from the slice and return it until
+// the slice is empty - at that point it will return nil every time.
+func NewTestEndpoint(ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) *TestEndpoint {
+ return &TestEndpoint{
+ Endpoint: ep,
+ WrittenPackets: make([]*stack.PacketBuffer, 0),
+ packetCollectorErrors: packetCollectorErrors,
+ }
+}
+
+// WritePacket stores outbound packets and may return an error if one was
+// injected.
+func (e *TestEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.WrittenPackets = append(e.WrittenPackets, pkt)
+
+ if len(e.packetCollectorErrors) > 0 {
+ nextError := e.packetCollectorErrors[0]
+ e.packetCollectorErrors = e.packetCollectorErrors[1:]
+ return nextError
+ }
+
+ return nil
+}
+
+// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
+// how many random bytes will be copied in the Transport Header.
+// extraHeaderReserveLength indicates how much extra space will be reserved for
+// the other headers. The payload is made from Views of the sizes listed in
+// viewSizes.
+func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer {
+ var views buffer.VectorisedView
+
+ for _, s := range viewSizes {
+ newView := buffer.NewView(s)
+ if _, err := rand.Read(newView); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ views.AppendView(newView)
+ }
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength,
+ Data: views,
+ })
+ pkt.NetworkProtocolNumber = proto
+ if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ return pkt
+}