summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorArthur Sfez <asfez@google.com>2020-09-22 15:04:11 -0700
committergVisor bot <gvisor-bot@google.com>2020-09-22 15:06:16 -0700
commitcf3cef1171bdfb41a27d563eb368d4488e0b99f1 (patch)
tree734bce9fbb5016aa5696cbaa9cfc565a2eed8a45
parent20dc83c9ecde1c4e99e10023c79008420fa0601f (diff)
Refactor testutil.TestEndpoint and use it instead of limitedEP
The new testutil.MockLinkEndpoint implementation is not composed by channel.Channel anymore because none of its features were used. PiperOrigin-RevId: 333167753
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go187
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go97
-rw-r--r--pkg/tcpip/network/testutil/BUILD7
-rw-r--r--pkg/tcpip/network/testutil/testutil.go102
5 files changed, 141 insertions, 253 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index b14bc98e8..86187aba8 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -17,6 +17,7 @@ package ipv4_test
import (
"bytes"
"encoding/hex"
+ "math"
"testing"
"github.com/google/go-cmp/cmp"
@@ -160,47 +161,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
}
-type testRoute struct {
- stack.Route
-
- linkEP *testutil.TestEndpoint
-}
-
-func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute {
- // Make the packet and write it.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- })
- testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors)
- s.CreateNIC(1, testEP)
- const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
- )
- s.AddAddress(1, ipv4.ProtocolNumber, src)
- {
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }})
- }
- r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("s.FindRoute got %v, want %v", err, nil)
- }
- t.Cleanup(func() {
- testEP.Close()
- })
- return testRoute{
- Route: r,
- linkEP: testEP,
- }
-}
-
func TestFragmentation(t *testing.T) {
var manyPayloadViewsSizes [1000]int
for i := range manyPayloadViewsSizes {
@@ -228,7 +188,8 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
source := pkt.Clone()
err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
@@ -237,16 +198,16 @@ func TestFragmentation(t *testing.T) {
TOS: stack.DefaultTOS,
}, pkt)
if err != nil {
- t.Errorf("err got %v, want %v", err, nil)
+ t.Errorf("got err = %s, want = nil", err)
}
- if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want {
- t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
+ if got := len(ep.WrittenPackets); got != ft.expectedFrags {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags)
}
- if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
- t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
+ if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
+ t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
- compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu)
+ compareFragments(t, ep.WrittenPackets, source, ft.mtu)
})
}
}
@@ -259,35 +220,30 @@ func TestFragmentationErrors(t *testing.T) {
mtu uint32
transportHeaderLength int
payloadViewsSizes []int
- packetCollectorErrors []*tcpip.Error
+ err *tcpip.Error
+ allowPackets int
}{
- {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
- {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0},
+ {"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0},
+ {"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1},
+ {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0},
}
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets)
+ r := buildRoute(t, ep)
pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
}, pkt)
- for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
- if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
- t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
- }
- }
- // We only need to check that last error because all the ones before are
- // nil.
- if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
- t.Errorf("err got %v, want %v", got, want)
+ if err != ft.err {
+ t.Errorf("got WritePacket() = %s, want = %s", err, ft.err)
}
- if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
- t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
+ if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
})
}
@@ -1052,7 +1008,7 @@ func TestWriteStats(t *testing.T) {
tests := []struct {
name string
setup func(*testing.T, *stack.Stack)
- linkEP func() stack.LinkEndpoint
+ allowPackets int
expectSent int
expectDropped int
expectWritten int
@@ -1061,7 +1017,7 @@ func TestWriteStats(t *testing.T) {
name: "Accept all",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: nPackets,
expectDropped: 0,
expectWritten: nPackets,
@@ -1069,7 +1025,7 @@ func TestWriteStats(t *testing.T) {
name: "Accept all with error",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} },
+ allowPackets: nPackets - 1,
expectSent: nPackets - 1,
expectDropped: 0,
expectWritten: nPackets - 1,
@@ -1086,10 +1042,10 @@ func TestWriteStats(t *testing.T) {
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = stack.DropTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
+ t.Fatalf("failed to replace table: %s", err)
}
},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: 0,
expectDropped: nPackets,
expectWritten: nPackets,
@@ -1111,10 +1067,10 @@ func TestWriteStats(t *testing.T) {
// Make sure the next rule is ACCEPT.
filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
+ t.Fatalf("failed to replace table: %s", err)
}
},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: nPackets - 1,
expectDropped: 1,
expectWritten: nPackets,
@@ -1150,7 +1106,8 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- rt := buildRoute(t, nil, test.linkEP())
+ ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
@@ -1181,101 +1138,37 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
})
- s.CreateNIC(1, linkEP)
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC(1, _) failed: %s", err)
+ }
const (
src = "\x10\x00\x00\x01"
dst = "\x10\x00\x00\x02"
)
- s.AddAddress(1, ipv4.ProtocolNumber, src)
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err)
+ }
{
subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
if err != nil {
- t.Fatal(err)
+ t.Fatalf("NewSubnet(_, _) failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
NIC: 1,
}})
}
- rt, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
+ rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err)
}
return rt
}
-// limitedEP is a link endpoint that writes up to a certain number of packets
-// before returning errors.
-type limitedEP struct {
- limit int
-}
-
-// MTU implements LinkEndpoint.MTU.
-func (*limitedEP) MTU() uint32 {
- // Give an MTU that won't cause fragmentation for IPv4+UDP.
- return header.IPv4MinimumSize + header.UDPMinimumSize
-}
-
-// Capabilities implements LinkEndpoint.Capabilities.
-func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 }
-
-// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
-func (*limitedEP) MaxHeaderLength() uint16 { return 0 }
-
-// LinkAddress implements LinkEndpoint.LinkAddress.
-func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" }
-
-// WritePacket implements LinkEndpoint.WritePacket.
-func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
- if ep.limit == 0 {
- return tcpip.ErrInvalidEndpointState
- }
- ep.limit--
- return nil
-}
-
-// WritePackets implements LinkEndpoint.WritePackets.
-func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- if ep.limit == 0 {
- return 0, tcpip.ErrInvalidEndpointState
- }
- nWritten := ep.limit
- if nWritten > pkts.Len() {
- nWritten = pkts.Len()
- }
- ep.limit -= nWritten
- return nWritten, nil
-}
-
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
- if ep.limit == 0 {
- return tcpip.ErrInvalidEndpointState
- }
- ep.limit--
- return nil
-}
-
-// Attach implements LinkEndpoint.Attach.
-func (*limitedEP) Attach(_ stack.NetworkDispatcher) {}
-
-// IsAttached implements LinkEndpoint.IsAttached.
-func (*limitedEP) IsAttached() bool { return false }
-
-// Wait implements LinkEndpoint.Wait.
-func (*limitedEP) Wait() {}
-
-// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
-func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }
-
-// AddHeader implements LinkEndpoint.AddHeader.
-func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
-}
-
// limitedMatcher is an iptables matcher that matches after a certain number of
// packets are checked against it.
type limitedMatcher struct {
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index cd5fe3ea8..8bd8f5c52 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -35,6 +35,7 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 9eea1de8d..7d138dadb 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -1715,7 +1716,7 @@ func TestWriteStats(t *testing.T) {
tests := []struct {
name string
setup func(*testing.T, *stack.Stack)
- linkEP func() stack.LinkEndpoint
+ allowPackets int
expectSent int
expectDropped int
expectWritten int
@@ -1724,7 +1725,7 @@ func TestWriteStats(t *testing.T) {
name: "Accept all",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: nPackets,
expectDropped: 0,
expectWritten: nPackets,
@@ -1732,7 +1733,7 @@ func TestWriteStats(t *testing.T) {
name: "Accept all with error",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} },
+ allowPackets: nPackets - 1,
expectSent: nPackets - 1,
expectDropped: 0,
expectWritten: nPackets - 1,
@@ -1752,7 +1753,7 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %v", err)
}
},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: 0,
expectDropped: nPackets,
expectWritten: nPackets,
@@ -1777,7 +1778,7 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %v", err)
}
},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: nPackets - 1,
expectDropped: 1,
expectWritten: nPackets,
@@ -1812,7 +1813,8 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- rt := buildRoute(t, nil, test.linkEP())
+ ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
@@ -1843,100 +1845,37 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
})
- s.CreateNIC(1, linkEP)
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC(1, _) failed: %s", err)
+ }
const (
src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
)
- s.AddAddress(1, ProtocolNumber, src)
+ if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err)
+ }
{
subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"))
if err != nil {
- t.Fatal(err)
+ t.Fatalf("NewSubnet(_, _) failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
NIC: 1,
}})
}
- rt, err := s.FindRoute(0, src, dst, ProtocolNumber, false /* multicastLoop */)
+ rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err)
}
return rt
}
-// limitedEP is a link endpoint that writes up to a certain number of packets
-// before returning errors.
-type limitedEP struct {
- limit int
-}
-
-// MTU implements LinkEndpoint.MTU.
-func (*limitedEP) MTU() uint32 {
- return header.IPv6MinimumMTU
-}
-
-// Capabilities implements LinkEndpoint.Capabilities.
-func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 }
-
-// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
-func (*limitedEP) MaxHeaderLength() uint16 { return 0 }
-
-// LinkAddress implements LinkEndpoint.LinkAddress.
-func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" }
-
-// WritePacket implements LinkEndpoint.WritePacket.
-func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
- if ep.limit == 0 {
- return tcpip.ErrInvalidEndpointState
- }
- ep.limit--
- return nil
-}
-
-// WritePackets implements LinkEndpoint.WritePackets.
-func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- if ep.limit == 0 {
- return 0, tcpip.ErrInvalidEndpointState
- }
- nWritten := ep.limit
- if nWritten > pkts.Len() {
- nWritten = pkts.Len()
- }
- ep.limit -= nWritten
- return nWritten, nil
-}
-
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
- if ep.limit == 0 {
- return tcpip.ErrInvalidEndpointState
- }
- ep.limit--
- return nil
-}
-
-// Attach implements LinkEndpoint.Attach.
-func (*limitedEP) Attach(_ stack.NetworkDispatcher) {}
-
-// IsAttached implements LinkEndpoint.IsAttached.
-func (*limitedEP) IsAttached() bool { return false }
-
-// Wait implements LinkEndpoint.Wait.
-func (*limitedEP) Wait() {}
-
-// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
-func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }
-
-// AddHeader implements LinkEndpoint.AddHeader.
-func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
-}
-
// limitedMatcher is an iptables matcher that matches after a certain number of
// packets are checked against it.
type limitedMatcher struct {
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
index e218563d0..c9e57dc0d 100644
--- a/pkg/tcpip/network/testutil/BUILD
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -7,11 +7,14 @@ go_library(
srcs = [
"testutil.go",
],
- visibility = ["//pkg/tcpip/network/ipv4:__pkg__"],
+ visibility = [
+ "//pkg/tcpip/network/ipv4:__pkg__",
+ "//pkg/tcpip/network/ipv6:__pkg__",
+ ],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
- "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
index bf5ce74be..7cc52985e 100644
--- a/pkg/tcpip/network/testutil/testutil.go
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -22,48 +22,100 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// TestEndpoint is an endpoint used for testing, it stores packets written to it
-// and can mock errors.
-type TestEndpoint struct {
- *channel.Endpoint
-
- // WrittenPackets is where we store packets written via WritePacket().
+// MockLinkEndpoint is an endpoint used for testing, it stores packets written
+// to it and can mock errors.
+type MockLinkEndpoint struct {
+ // WrittenPackets is where packets written to the endpoint are stored.
WrittenPackets []*stack.PacketBuffer
- packetCollectorErrors []*tcpip.Error
+ mtu uint32
+ err *tcpip.Error
+ allowPackets int
}
-// NewTestEndpoint creates a new TestEndpoint endpoint.
+// NewMockLinkEndpoint creates a new MockLinkEndpoint.
//
-// packetCollectorErrors can be used to set error values and each call to
-// WritePacket will remove the first one from the slice and return it until
-// the slice is empty - at that point it will return nil every time.
-func NewTestEndpoint(ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) *TestEndpoint {
- return &TestEndpoint{
- Endpoint: ep,
- WrittenPackets: make([]*stack.PacketBuffer, 0),
- packetCollectorErrors: packetCollectorErrors,
+// err is the error that will be returned once allowPackets packets are written
+// to the endpoint.
+func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint {
+ return &MockLinkEndpoint{
+ mtu: mtu,
+ err: err,
+ allowPackets: allowPackets,
+ }
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu }
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 }
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
+
+// LinkAddress implements LinkEndpoint.LinkAddress.
+func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (ep *MockLinkEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if ep.allowPackets == 0 {
+ return ep.err
}
+ ep.allowPackets--
+ ep.WrittenPackets = append(ep.WrittenPackets, pkt)
+ return nil
}
-// WritePacket stores outbound packets and may return an error if one was
-// injected.
-func (e *TestEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.WrittenPackets = append(e.WrittenPackets, pkt)
+// WritePackets implements LinkEndpoint.WritePackets.
+func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ var n int
- if len(e.packetCollectorErrors) > 0 {
- nextError := e.packetCollectorErrors[0]
- e.packetCollectorErrors = e.packetCollectorErrors[1:]
- return nextError
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := ep.WritePacket(r, gso, protocol, pkt); err != nil {
+ return n, err
+ }
+ n++
}
+ return n, nil
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ if ep.allowPackets == 0 {
+ return ep.err
+ }
+ ep.allowPackets--
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+ ep.WrittenPackets = append(ep.WrittenPackets, pkt)
+
return nil
}
+// Attach implements LinkEndpoint.Attach.
+func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (*MockLinkEndpoint) IsAttached() bool { return false }
+
+// Wait implements LinkEndpoint.Wait.
+func (*MockLinkEndpoint) Wait() {}
+
+// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
+func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
+
+// AddHeader implements LinkEndpoint.AddHeader.
+func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
// how many random bytes will be copied in the Transport Header.
// extraHeaderReserveLength indicates how much extra space will be reserved for