summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/ipv4
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go144
2 files changed, 34 insertions, 111 deletions
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index d142b4ffa..c82593e71 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -30,6 +30,7 @@ go_test(
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 2365b54f0..5e50558e8 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -17,8 +17,6 @@ package ipv4_test
import (
"bytes"
"encoding/hex"
- "fmt"
- "math/rand"
"testing"
"github.com/google/go-cmp/cmp"
@@ -28,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -92,31 +91,6 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-// makeRandPkt generates a randomize packet. hdrLength indicates how much
-// data should already be in the header before WritePacket. extraLength
-// indicates how much extra space should be in the header. The payload is made
-// from many Views of the sizes listed in viewSizes.
-func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer {
- var views []buffer.View
- totalLength := 0
- for _, s := range viewSizes {
- newView := buffer.NewView(s)
- rand.Read(newView)
- views = append(views, newView)
- totalLength += s
- }
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: hdrLength + extraLength,
- Data: buffer.NewVectorisedView(totalLength, views),
- })
- pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil {
- panic(fmt.Sprintf("rand.Read: %s", err))
- }
- return pkt
-}
-
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
@@ -186,63 +160,19 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
}
-type errorChannel struct {
- *channel.Endpoint
- Ch chan *stack.PacketBuffer
- packetCollectorErrors []*tcpip.Error
-}
-
-// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
-// will return successive errors from packetCollectorErrors until the list is
-// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
- return &errorChannel{
- Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan *stack.PacketBuffer, size),
- packetCollectorErrors: packetCollectorErrors,
- }
-}
-
-// Drain removes all outbound packets from the channel and counts them.
-func (e *errorChannel) Drain() int {
- c := 0
- for {
- select {
- case <-e.Ch:
- c++
- default:
- return c
- }
- }
-}
-
-// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- select {
- case e.Ch <- pkt:
- default:
- }
-
- nextError := (*tcpip.Error)(nil)
- if len(e.packetCollectorErrors) > 0 {
- nextError = e.packetCollectorErrors[0]
- e.packetCollectorErrors = e.packetCollectorErrors[1:]
- }
- return nextError
-}
-
-type context struct {
+type testRoute struct {
stack.Route
- linkEP *errorChannel
+
+ linkEP *testutil.TestEndpoint
}
-func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
+func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute {
// Make the packet and write it.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
})
- ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
- s.CreateNIC(1, ep)
+ testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors)
+ s.CreateNIC(1, testEP)
const (
src = "\x10\x00\x00\x01"
dst = "\x10\x00\x00\x02"
@@ -262,9 +192,12 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32
if err != nil {
t.Fatalf("s.FindRoute got %v, want %v", err, nil)
}
- return context{
+ t.Cleanup(func() {
+ testEP.Close()
+ })
+ return testRoute{
Route: r,
- linkEP: ep,
+ linkEP: testEP,
}
}
@@ -274,13 +207,13 @@ func TestFragmentation(t *testing.T) {
manyPayloadViewsSizes[i] = 7
}
fragTests := []struct {
- description string
- mtu uint32
- gso *stack.GSO
- hdrLength int
- extraLength int
- payloadViewsSizes []int
- expectedFrags int
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transportHeaderLength int
+ extraHeaderReserveLength int
+ payloadViewsSizes []int
+ expectedFrags int
}{
{"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
{"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
@@ -295,10 +228,10 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
+ r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
source := pkt.Clone()
- c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
@@ -307,24 +240,13 @@ func TestFragmentation(t *testing.T) {
t.Errorf("err got %v, want %v", err, nil)
}
- var results []*stack.PacketBuffer
- L:
- for {
- select {
- case pi := <-c.linkEP.Ch:
- results = append(results, pi)
- default:
- break L
- }
- }
-
- if got, want := len(results), ft.expectedFrags; got != want {
- t.Errorf("len(result) got %d, want %d", got, want)
+ if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want {
+ t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
}
- if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
- t.Errorf("no errors yet len(result) got %d, want %d", got, want)
+ if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
+ t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
}
- compareFragments(t, results, source, ft.mtu)
+ compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu)
})
}
}
@@ -335,21 +257,21 @@ func TestFragmentationErrors(t *testing.T) {
fragTests := []struct {
description string
mtu uint32
- hdrLength int
+ transportHeaderLength int
payloadViewsSizes []int
packetCollectorErrors []*tcpip.Error
}{
{"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
{"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
{"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
- {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
}
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
- c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
@@ -364,7 +286,7 @@ func TestFragmentationErrors(t *testing.T) {
if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
t.Errorf("err got %v, want %v", got, want)
}
- if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
+ if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
}
})