From 72a30b11486b48394fa0edca500b80e4ca83b10c Mon Sep 17 00:00:00 2001
From: Arthur Sfez <asfez@google.com>
Date: Tue, 15 Sep 2020 14:47:34 -0700
Subject: Move reusable IPv4 test code into a testutil module and refactor it

The refactor aims to simplify the package, by replacing the Go channel with a
PacketBuffer slice.

This code will be reused by tests for IPv6 fragmentation.

PiperOrigin-RevId: 331860411
---
 pkg/tcpip/network/ipv4/BUILD           |   1 +
 pkg/tcpip/network/ipv4/ipv4_test.go    | 144 ++++++++-------------------------
 pkg/tcpip/network/testutil/BUILD       |  17 ++++
 pkg/tcpip/network/testutil/testutil.go |  92 +++++++++++++++++++++
 4 files changed, 143 insertions(+), 111 deletions(-)
 create mode 100644 pkg/tcpip/network/testutil/BUILD
 create mode 100644 pkg/tcpip/network/testutil/testutil.go

(limited to 'pkg/tcpip')

diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index d142b4ffa..c82593e71 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -30,6 +30,7 @@ go_test(
         "//pkg/tcpip/link/channel",
         "//pkg/tcpip/link/sniffer",
         "//pkg/tcpip/network/ipv4",
+        "//pkg/tcpip/network/testutil",
         "//pkg/tcpip/stack",
         "//pkg/tcpip/transport/tcp",
         "//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 2365b54f0..5e50558e8 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -17,8 +17,6 @@ package ipv4_test
 import (
 	"bytes"
 	"encoding/hex"
-	"fmt"
-	"math/rand"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -28,6 +26,7 @@ import (
 	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
 	"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
 	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+	"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
 	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
 	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -92,31 +91,6 @@ func TestExcludeBroadcast(t *testing.T) {
 	})
 }
 
-// makeRandPkt generates a randomize packet. hdrLength indicates how much
-// data should already be in the header before WritePacket. extraLength
-// indicates how much extra space should be in the header. The payload is made
-// from many Views of the sizes listed in viewSizes.
-func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer {
-	var views []buffer.View
-	totalLength := 0
-	for _, s := range viewSizes {
-		newView := buffer.NewView(s)
-		rand.Read(newView)
-		views = append(views, newView)
-		totalLength += s
-	}
-
-	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
-		ReserveHeaderBytes: hdrLength + extraLength,
-		Data:               buffer.NewVectorisedView(totalLength, views),
-	})
-	pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
-	if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil {
-		panic(fmt.Sprintf("rand.Read: %s", err))
-	}
-	return pkt
-}
-
 // comparePayloads compared the contents of all the packets against the contents
 // of the source packet.
 func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
@@ -186,63 +160,19 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
 	}
 }
 
-type errorChannel struct {
-	*channel.Endpoint
-	Ch                    chan *stack.PacketBuffer
-	packetCollectorErrors []*tcpip.Error
-}
-
-// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
-// will return successive errors from packetCollectorErrors until the list is
-// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
-	return &errorChannel{
-		Endpoint:              channel.New(size, mtu, linkAddr),
-		Ch:                    make(chan *stack.PacketBuffer, size),
-		packetCollectorErrors: packetCollectorErrors,
-	}
-}
-
-// Drain removes all outbound packets from the channel and counts them.
-func (e *errorChannel) Drain() int {
-	c := 0
-	for {
-		select {
-		case <-e.Ch:
-			c++
-		default:
-			return c
-		}
-	}
-}
-
-// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
-	select {
-	case e.Ch <- pkt:
-	default:
-	}
-
-	nextError := (*tcpip.Error)(nil)
-	if len(e.packetCollectorErrors) > 0 {
-		nextError = e.packetCollectorErrors[0]
-		e.packetCollectorErrors = e.packetCollectorErrors[1:]
-	}
-	return nextError
-}
-
-type context struct {
+type testRoute struct {
 	stack.Route
-	linkEP *errorChannel
+
+	linkEP *testutil.TestEndpoint
 }
 
-func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
+func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute {
 	// Make the packet and write it.
 	s := stack.New(stack.Options{
 		NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
 	})
-	ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
-	s.CreateNIC(1, ep)
+	testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors)
+	s.CreateNIC(1, testEP)
 	const (
 		src = "\x10\x00\x00\x01"
 		dst = "\x10\x00\x00\x02"
@@ -262,9 +192,12 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32
 	if err != nil {
 		t.Fatalf("s.FindRoute got %v, want %v", err, nil)
 	}
-	return context{
+	t.Cleanup(func() {
+		testEP.Close()
+	})
+	return testRoute{
 		Route:  r,
-		linkEP: ep,
+		linkEP: testEP,
 	}
 }
 
@@ -274,13 +207,13 @@ func TestFragmentation(t *testing.T) {
 		manyPayloadViewsSizes[i] = 7
 	}
 	fragTests := []struct {
-		description       string
-		mtu               uint32
-		gso               *stack.GSO
-		hdrLength         int
-		extraLength       int
-		payloadViewsSizes []int
-		expectedFrags     int
+		description              string
+		mtu                      uint32
+		gso                      *stack.GSO
+		transportHeaderLength    int
+		extraHeaderReserveLength int
+		payloadViewsSizes        []int
+		expectedFrags            int
 	}{
 		{"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
 		{"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
@@ -295,10 +228,10 @@ func TestFragmentation(t *testing.T) {
 
 	for _, ft := range fragTests {
 		t.Run(ft.description, func(t *testing.T) {
-			pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
+			r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil)
+			pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
 			source := pkt.Clone()
-			c := buildContext(t, nil, ft.mtu)
-			err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{
+			err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
 				Protocol: tcp.ProtocolNumber,
 				TTL:      42,
 				TOS:      stack.DefaultTOS,
@@ -307,24 +240,13 @@ func TestFragmentation(t *testing.T) {
 				t.Errorf("err got %v, want %v", err, nil)
 			}
 
-			var results []*stack.PacketBuffer
-		L:
-			for {
-				select {
-				case pi := <-c.linkEP.Ch:
-					results = append(results, pi)
-				default:
-					break L
-				}
-			}
-
-			if got, want := len(results), ft.expectedFrags; got != want {
-				t.Errorf("len(result) got %d, want %d", got, want)
+			if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want {
+				t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
 			}
-			if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
-				t.Errorf("no errors yet len(result) got %d, want %d", got, want)
+			if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
+				t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
 			}
-			compareFragments(t, results, source, ft.mtu)
+			compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu)
 		})
 	}
 }
@@ -335,21 +257,21 @@ func TestFragmentationErrors(t *testing.T) {
 	fragTests := []struct {
 		description           string
 		mtu                   uint32
-		hdrLength             int
+		transportHeaderLength int
 		payloadViewsSizes     []int
 		packetCollectorErrors []*tcpip.Error
 	}{
 		{"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
 		{"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
 		{"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
-		{"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+		{"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
 	}
 
 	for _, ft := range fragTests {
 		t.Run(ft.description, func(t *testing.T) {
-			pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
-			c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
-			err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+			r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors)
+			pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+			err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
 				Protocol: tcp.ProtocolNumber,
 				TTL:      42,
 				TOS:      stack.DefaultTOS,
@@ -364,7 +286,7 @@ func TestFragmentationErrors(t *testing.T) {
 			if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
 				t.Errorf("err got %v, want %v", got, want)
 			}
-			if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
+			if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
 				t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
 			}
 		})
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
new file mode 100644
index 000000000..e218563d0
--- /dev/null
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+    name = "testutil",
+    srcs = [
+        "testutil.go",
+    ],
+    visibility = ["//pkg/tcpip/network/ipv4:__pkg__"],
+    deps = [
+        "//pkg/tcpip",
+        "//pkg/tcpip/buffer",
+        "//pkg/tcpip/link/channel",
+        "//pkg/tcpip/stack",
+    ],
+)
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
new file mode 100644
index 000000000..bf5ce74be
--- /dev/null
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -0,0 +1,92 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil defines types and functions used to test Network Layer
+// functionality such as IP fragmentation.
+package testutil
+
+import (
+	"fmt"
+	"math/rand"
+
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// TestEndpoint is an endpoint used for testing, it stores packets written to it
+// and can mock errors.
+type TestEndpoint struct {
+	*channel.Endpoint
+
+	// WrittenPackets is where we store packets written via WritePacket().
+	WrittenPackets []*stack.PacketBuffer
+
+	packetCollectorErrors []*tcpip.Error
+}
+
+// NewTestEndpoint creates a new TestEndpoint endpoint.
+//
+// packetCollectorErrors can be used to set error values and each call to
+// WritePacket will remove the first one from the slice and return it until
+// the slice is empty - at that point it will return nil every time.
+func NewTestEndpoint(ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) *TestEndpoint {
+	return &TestEndpoint{
+		Endpoint:              ep,
+		WrittenPackets:        make([]*stack.PacketBuffer, 0),
+		packetCollectorErrors: packetCollectorErrors,
+	}
+}
+
+// WritePacket stores outbound packets and may return an error if one was
+// injected.
+func (e *TestEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+	e.WrittenPackets = append(e.WrittenPackets, pkt)
+
+	if len(e.packetCollectorErrors) > 0 {
+		nextError := e.packetCollectorErrors[0]
+		e.packetCollectorErrors = e.packetCollectorErrors[1:]
+		return nextError
+	}
+
+	return nil
+}
+
+// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
+// how many random bytes will be copied in the Transport Header.
+// extraHeaderReserveLength indicates how much extra space will be reserved for
+// the other headers. The payload is made from Views of the sizes listed in
+// viewSizes.
+func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer {
+	var views buffer.VectorisedView
+
+	for _, s := range viewSizes {
+		newView := buffer.NewView(s)
+		if _, err := rand.Read(newView); err != nil {
+			panic(fmt.Sprintf("rand.Read: %s", err))
+		}
+		views.AppendView(newView)
+	}
+
+	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+		ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength,
+		Data:               views,
+	})
+	pkt.NetworkProtocolNumber = proto
+	if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil {
+		panic(fmt.Sprintf("rand.Read: %s", err))
+	}
+	return pkt
+}
-- 
cgit v1.2.3