From cf3cef1171bdfb41a27d563eb368d4488e0b99f1 Mon Sep 17 00:00:00 2001
From: Arthur Sfez <asfez@google.com>
Date: Tue, 22 Sep 2020 15:04:11 -0700
Subject: Refactor testutil.TestEndpoint and use it instead of limitedEP

The new testutil.MockLinkEndpoint implementation is not composed by
channel.Channel anymore because none of its features were used.

PiperOrigin-RevId: 333167753
---
 pkg/tcpip/network/ipv4/ipv4_test.go    | 187 +++++++--------------------------
 pkg/tcpip/network/ipv6/BUILD           |   1 +
 pkg/tcpip/network/ipv6/ipv6_test.go    |  97 ++++-------------
 pkg/tcpip/network/testutil/BUILD       |   7 +-
 pkg/tcpip/network/testutil/testutil.go | 102 +++++++++++++-----
 5 files changed, 141 insertions(+), 253 deletions(-)

(limited to 'pkg/tcpip')

diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index b14bc98e8..86187aba8 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -17,6 +17,7 @@ package ipv4_test
 import (
 	"bytes"
 	"encoding/hex"
+	"math"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -160,47 +161,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
 	}
 }
 
-type testRoute struct {
-	stack.Route
-
-	linkEP *testutil.TestEndpoint
-}
-
-func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute {
-	// Make the packet and write it.
-	s := stack.New(stack.Options{
-		NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
-	})
-	testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors)
-	s.CreateNIC(1, testEP)
-	const (
-		src = "\x10\x00\x00\x01"
-		dst = "\x10\x00\x00\x02"
-	)
-	s.AddAddress(1, ipv4.ProtocolNumber, src)
-	{
-		subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
-		if err != nil {
-			t.Fatal(err)
-		}
-		s.SetRouteTable([]tcpip.Route{{
-			Destination: subnet,
-			NIC:         1,
-		}})
-	}
-	r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
-	if err != nil {
-		t.Fatalf("s.FindRoute got %v, want %v", err, nil)
-	}
-	t.Cleanup(func() {
-		testEP.Close()
-	})
-	return testRoute{
-		Route:  r,
-		linkEP: testEP,
-	}
-}
-
 func TestFragmentation(t *testing.T) {
 	var manyPayloadViewsSizes [1000]int
 	for i := range manyPayloadViewsSizes {
@@ -228,7 +188,8 @@ func TestFragmentation(t *testing.T) {
 
 	for _, ft := range fragTests {
 		t.Run(ft.description, func(t *testing.T) {
-			r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil)
+			ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+			r := buildRoute(t, ep)
 			pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
 			source := pkt.Clone()
 			err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
@@ -237,16 +198,16 @@ func TestFragmentation(t *testing.T) {
 				TOS:      stack.DefaultTOS,
 			}, pkt)
 			if err != nil {
-				t.Errorf("err got %v, want %v", err, nil)
+				t.Errorf("got err = %s, want = nil", err)
 			}
 
-			if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want {
-				t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
+			if got := len(ep.WrittenPackets); got != ft.expectedFrags {
+				t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags)
 			}
-			if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
-				t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
+			if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
+				t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want)
 			}
-			compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu)
+			compareFragments(t, ep.WrittenPackets, source, ft.mtu)
 		})
 	}
 }
@@ -259,35 +220,30 @@ func TestFragmentationErrors(t *testing.T) {
 		mtu                   uint32
 		transportHeaderLength int
 		payloadViewsSizes     []int
-		packetCollectorErrors []*tcpip.Error
+		err                   *tcpip.Error
+		allowPackets          int
 	}{
-		{"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
-		{"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
-		{"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
-		{"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+		{"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0},
+		{"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0},
+		{"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1},
+		{"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0},
 	}
 
 	for _, ft := range fragTests {
 		t.Run(ft.description, func(t *testing.T) {
-			r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors)
+			ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets)
+			r := buildRoute(t, ep)
 			pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
 			err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
 				Protocol: tcp.ProtocolNumber,
 				TTL:      42,
 				TOS:      stack.DefaultTOS,
 			}, pkt)
-			for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
-				if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
-					t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
-				}
-			}
-			// We only need to check that last error because all the ones before are
-			// nil.
-			if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
-				t.Errorf("err got %v, want %v", got, want)
+			if err != ft.err {
+				t.Errorf("got WritePacket() = %s, want = %s", err, ft.err)
 			}
-			if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
-				t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
+			if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
+				t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
 			}
 		})
 	}
@@ -1052,7 +1008,7 @@ func TestWriteStats(t *testing.T) {
 	tests := []struct {
 		name          string
 		setup         func(*testing.T, *stack.Stack)
-		linkEP        func() stack.LinkEndpoint
+		allowPackets  int
 		expectSent    int
 		expectDropped int
 		expectWritten int
@@ -1061,7 +1017,7 @@ func TestWriteStats(t *testing.T) {
 			name: "Accept all",
 			// No setup needed, tables accept everything by default.
 			setup:         func(*testing.T, *stack.Stack) {},
-			linkEP:        func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+			allowPackets:  math.MaxInt32,
 			expectSent:    nPackets,
 			expectDropped: 0,
 			expectWritten: nPackets,
@@ -1069,7 +1025,7 @@ func TestWriteStats(t *testing.T) {
 			name: "Accept all with error",
 			// No setup needed, tables accept everything by default.
 			setup:         func(*testing.T, *stack.Stack) {},
-			linkEP:        func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} },
+			allowPackets:  nPackets - 1,
 			expectSent:    nPackets - 1,
 			expectDropped: 0,
 			expectWritten: nPackets - 1,
@@ -1086,10 +1042,10 @@ func TestWriteStats(t *testing.T) {
 				ruleIdx := filter.BuiltinChains[stack.Output]
 				filter.Rules[ruleIdx].Target = stack.DropTarget{}
 				if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
-					t.Fatalf("failed to replace table: %v", err)
+					t.Fatalf("failed to replace table: %s", err)
 				}
 			},
-			linkEP:        func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+			allowPackets:  math.MaxInt32,
 			expectSent:    0,
 			expectDropped: nPackets,
 			expectWritten: nPackets,
@@ -1111,10 +1067,10 @@ func TestWriteStats(t *testing.T) {
 				// Make sure the next rule is ACCEPT.
 				filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{}
 				if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
-					t.Fatalf("failed to replace table: %v", err)
+					t.Fatalf("failed to replace table: %s", err)
 				}
 			},
-			linkEP:        func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+			allowPackets:  math.MaxInt32,
 			expectSent:    nPackets - 1,
 			expectDropped: 1,
 			expectWritten: nPackets,
@@ -1150,7 +1106,8 @@ func TestWriteStats(t *testing.T) {
 		t.Run(writer.name, func(t *testing.T) {
 			for _, test := range tests {
 				t.Run(test.name, func(t *testing.T) {
-					rt := buildRoute(t, nil, test.linkEP())
+					ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets)
+					rt := buildRoute(t, ep)
 
 					var pkts stack.PacketBufferList
 					for i := 0; i < nPackets; i++ {
@@ -1181,101 +1138,37 @@ func TestWriteStats(t *testing.T) {
 	}
 }
 
-func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
 	s := stack.New(stack.Options{
 		NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
 	})
-	s.CreateNIC(1, linkEP)
+	if err := s.CreateNIC(1, ep); err != nil {
+		t.Fatalf("CreateNIC(1, _) failed: %s", err)
+	}
 	const (
 		src = "\x10\x00\x00\x01"
 		dst = "\x10\x00\x00\x02"
 	)
-	s.AddAddress(1, ipv4.ProtocolNumber, src)
+	if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
+		t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err)
+	}
 	{
 		subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
 		if err != nil {
-			t.Fatal(err)
+			t.Fatalf("NewSubnet(_, _) failed: %v", err)
 		}
 		s.SetRouteTable([]tcpip.Route{{
 			Destination: subnet,
 			NIC:         1,
 		}})
 	}
-	rt, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
+	rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
 	if err != nil {
-		t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+		t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err)
 	}
 	return rt
 }
 
-// limitedEP is a link endpoint that writes up to a certain number of packets
-// before returning errors.
-type limitedEP struct {
-	limit int
-}
-
-// MTU implements LinkEndpoint.MTU.
-func (*limitedEP) MTU() uint32 {
-	// Give an MTU that won't cause fragmentation for IPv4+UDP.
-	return header.IPv4MinimumSize + header.UDPMinimumSize
-}
-
-// Capabilities implements LinkEndpoint.Capabilities.
-func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 }
-
-// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
-func (*limitedEP) MaxHeaderLength() uint16 { return 0 }
-
-// LinkAddress implements LinkEndpoint.LinkAddress.
-func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" }
-
-// WritePacket implements LinkEndpoint.WritePacket.
-func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
-	if ep.limit == 0 {
-		return tcpip.ErrInvalidEndpointState
-	}
-	ep.limit--
-	return nil
-}
-
-// WritePackets implements LinkEndpoint.WritePackets.
-func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
-	if ep.limit == 0 {
-		return 0, tcpip.ErrInvalidEndpointState
-	}
-	nWritten := ep.limit
-	if nWritten > pkts.Len() {
-		nWritten = pkts.Len()
-	}
-	ep.limit -= nWritten
-	return nWritten, nil
-}
-
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
-	if ep.limit == 0 {
-		return tcpip.ErrInvalidEndpointState
-	}
-	ep.limit--
-	return nil
-}
-
-// Attach implements LinkEndpoint.Attach.
-func (*limitedEP) Attach(_ stack.NetworkDispatcher) {}
-
-// IsAttached implements LinkEndpoint.IsAttached.
-func (*limitedEP) IsAttached() bool { return false }
-
-// Wait implements LinkEndpoint.Wait.
-func (*limitedEP) Wait() {}
-
-// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
-func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }
-
-// AddHeader implements LinkEndpoint.AddHeader.
-func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
-}
-
 // limitedMatcher is an iptables matcher that matches after a certain number of
 // packets are checked against it.
 type limitedMatcher struct {
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index cd5fe3ea8..8bd8f5c52 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -35,6 +35,7 @@ go_test(
         "//pkg/tcpip/header",
         "//pkg/tcpip/link/channel",
         "//pkg/tcpip/link/sniffer",
+        "//pkg/tcpip/network/testutil",
         "//pkg/tcpip/stack",
         "//pkg/tcpip/transport/icmp",
         "//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 9eea1de8d..7d138dadb 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -23,6 +23,7 @@ import (
 	"gvisor.dev/gvisor/pkg/tcpip/buffer"
 	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+	"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
 	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
 	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -1715,7 +1716,7 @@ func TestWriteStats(t *testing.T) {
 	tests := []struct {
 		name          string
 		setup         func(*testing.T, *stack.Stack)
-		linkEP        func() stack.LinkEndpoint
+		allowPackets  int
 		expectSent    int
 		expectDropped int
 		expectWritten int
@@ -1724,7 +1725,7 @@ func TestWriteStats(t *testing.T) {
 			name: "Accept all",
 			// No setup needed, tables accept everything by default.
 			setup:         func(*testing.T, *stack.Stack) {},
-			linkEP:        func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+			allowPackets:  math.MaxInt32,
 			expectSent:    nPackets,
 			expectDropped: 0,
 			expectWritten: nPackets,
@@ -1732,7 +1733,7 @@ func TestWriteStats(t *testing.T) {
 			name: "Accept all with error",
 			// No setup needed, tables accept everything by default.
 			setup:         func(*testing.T, *stack.Stack) {},
-			linkEP:        func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} },
+			allowPackets:  nPackets - 1,
 			expectSent:    nPackets - 1,
 			expectDropped: 0,
 			expectWritten: nPackets - 1,
@@ -1752,7 +1753,7 @@ func TestWriteStats(t *testing.T) {
 					t.Fatalf("failed to replace table: %v", err)
 				}
 			},
-			linkEP:        func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+			allowPackets:  math.MaxInt32,
 			expectSent:    0,
 			expectDropped: nPackets,
 			expectWritten: nPackets,
@@ -1777,7 +1778,7 @@ func TestWriteStats(t *testing.T) {
 					t.Fatalf("failed to replace table: %v", err)
 				}
 			},
-			linkEP:        func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+			allowPackets:  math.MaxInt32,
 			expectSent:    nPackets - 1,
 			expectDropped: 1,
 			expectWritten: nPackets,
@@ -1812,7 +1813,8 @@ func TestWriteStats(t *testing.T) {
 		t.Run(writer.name, func(t *testing.T) {
 			for _, test := range tests {
 				t.Run(test.name, func(t *testing.T) {
-					rt := buildRoute(t, nil, test.linkEP())
+					ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
+					rt := buildRoute(t, ep)
 
 					var pkts stack.PacketBufferList
 					for i := 0; i < nPackets; i++ {
@@ -1843,100 +1845,37 @@ func TestWriteStats(t *testing.T) {
 	}
 }
 
-func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
 	s := stack.New(stack.Options{
 		NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
 	})
-	s.CreateNIC(1, linkEP)
+	if err := s.CreateNIC(1, ep); err != nil {
+		t.Fatalf("CreateNIC(1, _) failed: %s", err)
+	}
 	const (
 		src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
 		dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
 	)
-	s.AddAddress(1, ProtocolNumber, src)
+	if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
+		t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err)
+	}
 	{
 		subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"))
 		if err != nil {
-			t.Fatal(err)
+			t.Fatalf("NewSubnet(_, _) failed: %v", err)
 		}
 		s.SetRouteTable([]tcpip.Route{{
 			Destination: subnet,
 			NIC:         1,
 		}})
 	}
-	rt, err := s.FindRoute(0, src, dst, ProtocolNumber, false /* multicastLoop */)
+	rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */)
 	if err != nil {
-		t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+		t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err)
 	}
 	return rt
 }
 
-// limitedEP is a link endpoint that writes up to a certain number of packets
-// before returning errors.
-type limitedEP struct {
-	limit int
-}
-
-// MTU implements LinkEndpoint.MTU.
-func (*limitedEP) MTU() uint32 {
-	return header.IPv6MinimumMTU
-}
-
-// Capabilities implements LinkEndpoint.Capabilities.
-func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 }
-
-// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
-func (*limitedEP) MaxHeaderLength() uint16 { return 0 }
-
-// LinkAddress implements LinkEndpoint.LinkAddress.
-func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" }
-
-// WritePacket implements LinkEndpoint.WritePacket.
-func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
-	if ep.limit == 0 {
-		return tcpip.ErrInvalidEndpointState
-	}
-	ep.limit--
-	return nil
-}
-
-// WritePackets implements LinkEndpoint.WritePackets.
-func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
-	if ep.limit == 0 {
-		return 0, tcpip.ErrInvalidEndpointState
-	}
-	nWritten := ep.limit
-	if nWritten > pkts.Len() {
-		nWritten = pkts.Len()
-	}
-	ep.limit -= nWritten
-	return nWritten, nil
-}
-
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
-	if ep.limit == 0 {
-		return tcpip.ErrInvalidEndpointState
-	}
-	ep.limit--
-	return nil
-}
-
-// Attach implements LinkEndpoint.Attach.
-func (*limitedEP) Attach(_ stack.NetworkDispatcher) {}
-
-// IsAttached implements LinkEndpoint.IsAttached.
-func (*limitedEP) IsAttached() bool { return false }
-
-// Wait implements LinkEndpoint.Wait.
-func (*limitedEP) Wait() {}
-
-// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
-func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }
-
-// AddHeader implements LinkEndpoint.AddHeader.
-func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
-}
-
 // limitedMatcher is an iptables matcher that matches after a certain number of
 // packets are checked against it.
 type limitedMatcher struct {
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
index e218563d0..c9e57dc0d 100644
--- a/pkg/tcpip/network/testutil/BUILD
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -7,11 +7,14 @@ go_library(
     srcs = [
         "testutil.go",
     ],
-    visibility = ["//pkg/tcpip/network/ipv4:__pkg__"],
+    visibility = [
+        "//pkg/tcpip/network/ipv4:__pkg__",
+        "//pkg/tcpip/network/ipv6:__pkg__",
+    ],
     deps = [
         "//pkg/tcpip",
         "//pkg/tcpip/buffer",
-        "//pkg/tcpip/link/channel",
+        "//pkg/tcpip/header",
         "//pkg/tcpip/stack",
     ],
 )
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
index bf5ce74be..7cc52985e 100644
--- a/pkg/tcpip/network/testutil/testutil.go
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -22,48 +22,100 @@ import (
 
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/buffer"
-	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
 )
 
-// TestEndpoint is an endpoint used for testing, it stores packets written to it
-// and can mock errors.
-type TestEndpoint struct {
-	*channel.Endpoint
-
-	// WrittenPackets is where we store packets written via WritePacket().
+// MockLinkEndpoint is an endpoint used for testing, it stores packets written
+// to it and can mock errors.
+type MockLinkEndpoint struct {
+	// WrittenPackets is where packets written to the endpoint are stored.
 	WrittenPackets []*stack.PacketBuffer
 
-	packetCollectorErrors []*tcpip.Error
+	mtu          uint32
+	err          *tcpip.Error
+	allowPackets int
 }
 
-// NewTestEndpoint creates a new TestEndpoint endpoint.
+// NewMockLinkEndpoint creates a new MockLinkEndpoint.
 //
-// packetCollectorErrors can be used to set error values and each call to
-// WritePacket will remove the first one from the slice and return it until
-// the slice is empty - at that point it will return nil every time.
-func NewTestEndpoint(ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) *TestEndpoint {
-	return &TestEndpoint{
-		Endpoint:              ep,
-		WrittenPackets:        make([]*stack.PacketBuffer, 0),
-		packetCollectorErrors: packetCollectorErrors,
+// err is the error that will be returned once allowPackets packets are written
+// to the endpoint.
+func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint {
+	return &MockLinkEndpoint{
+		mtu:          mtu,
+		err:          err,
+		allowPackets: allowPackets,
+	}
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu }
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 }
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
+
+// LinkAddress implements LinkEndpoint.LinkAddress.
+func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (ep *MockLinkEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+	if ep.allowPackets == 0 {
+		return ep.err
 	}
+	ep.allowPackets--
+	ep.WrittenPackets = append(ep.WrittenPackets, pkt)
+	return nil
 }
 
-// WritePacket stores outbound packets and may return an error if one was
-// injected.
-func (e *TestEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
-	e.WrittenPackets = append(e.WrittenPackets, pkt)
+// WritePackets implements LinkEndpoint.WritePackets.
+func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+	var n int
 
-	if len(e.packetCollectorErrors) > 0 {
-		nextError := e.packetCollectorErrors[0]
-		e.packetCollectorErrors = e.packetCollectorErrors[1:]
-		return nextError
+	for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+		if err := ep.WritePacket(r, gso, protocol, pkt); err != nil {
+			return n, err
+		}
+		n++
 	}
 
+	return n, nil
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+	if ep.allowPackets == 0 {
+		return ep.err
+	}
+	ep.allowPackets--
+
+	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+		Data: vv,
+	})
+	ep.WrittenPackets = append(ep.WrittenPackets, pkt)
+
 	return nil
 }
 
+// Attach implements LinkEndpoint.Attach.
+func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (*MockLinkEndpoint) IsAttached() bool { return false }
+
+// Wait implements LinkEndpoint.Wait.
+func (*MockLinkEndpoint) Wait() {}
+
+// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
+func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
+
+// AddHeader implements LinkEndpoint.AddHeader.
+func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
 // MakeRandPkt generates a randomized packet. transportHeaderLength indicates
 // how many random bytes will be copied in the Transport Header.
 // extraHeaderReserveLength indicates how much extra space will be reserved for
-- 
cgit v1.2.3