diff options
author | Arthur Sfez <asfez@google.com> | 2020-09-22 15:04:11 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-09-22 15:06:16 -0700 |
commit | cf3cef1171bdfb41a27d563eb368d4488e0b99f1 (patch) | |
tree | 734bce9fbb5016aa5696cbaa9cfc565a2eed8a45 | |
parent | 20dc83c9ecde1c4e99e10023c79008420fa0601f (diff) |
Refactor testutil.TestEndpoint and use it instead of limitedEP
The new testutil.MockLinkEndpoint implementation is not composed by
channel.Channel anymore because none of its features were used.
PiperOrigin-RevId: 333167753
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 187 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6_test.go | 97 | ||||
-rw-r--r-- | pkg/tcpip/network/testutil/BUILD | 7 | ||||
-rw-r--r-- | pkg/tcpip/network/testutil/testutil.go | 102 |
5 files changed, 141 insertions, 253 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index b14bc98e8..86187aba8 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -17,6 +17,7 @@ package ipv4_test import ( "bytes" "encoding/hex" + "math" "testing" "github.com/google/go-cmp/cmp" @@ -160,47 +161,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI } } -type testRoute struct { - stack.Route - - linkEP *testutil.TestEndpoint -} - -func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute { - // Make the packet and write it. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - }) - testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors) - s.CreateNIC(1, testEP) - const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" - ) - s.AddAddress(1, ipv4.ProtocolNumber, src) - { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) - } - t.Cleanup(func() { - testEP.Close() - }) - return testRoute{ - Route: r, - linkEP: testEP, - } -} - func TestFragmentation(t *testing.T) { var manyPayloadViewsSizes [1000]int for i := range manyPayloadViewsSizes { @@ -228,7 +188,8 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil) + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber) source := pkt.Clone() err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ @@ -237,16 +198,16 @@ func TestFragmentation(t *testing.T) { TOS: stack.DefaultTOS, }, pkt) if err != nil { - t.Errorf("err got %v, want %v", err, nil) + t.Errorf("got err = %s, want = nil", err) } - if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want { - t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want) + if got := len(ep.WrittenPackets); got != ft.expectedFrags { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags) } - if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want) + if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { + t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want) } - compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu) + compareFragments(t, ep.WrittenPackets, source, ft.mtu) }) } } @@ -259,35 +220,30 @@ func TestFragmentationErrors(t *testing.T) { mtu uint32 transportHeaderLength int payloadViewsSizes []int - packetCollectorErrors []*tcpip.Error + err *tcpip.Error + allowPackets int }{ - {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, - {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, + {"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0}, + {"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0}, + {"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1}, + {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0}, } for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors) + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets) + r := buildRoute(t, ep) pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, }, pkt) - for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { - if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { - t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) - } - } - // We only need to check that last error because all the ones before are - // nil. - if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { - t.Errorf("err got %v, want %v", got, want) + if err != ft.err { + t.Errorf("got WritePacket() = %s, want = %s", err, ft.err) } - if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { - t.Errorf("after linkEP error len(result) got %d, want %d", got, want) + if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) } }) } @@ -1052,7 +1008,7 @@ func TestWriteStats(t *testing.T) { tests := []struct { name string setup func(*testing.T, *stack.Stack) - linkEP func() stack.LinkEndpoint + allowPackets int expectSent int expectDropped int expectWritten int @@ -1061,7 +1017,7 @@ func TestWriteStats(t *testing.T) { name: "Accept all", // No setup needed, tables accept everything by default. setup: func(*testing.T, *stack.Stack) {}, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: nPackets, expectDropped: 0, expectWritten: nPackets, @@ -1069,7 +1025,7 @@ func TestWriteStats(t *testing.T) { name: "Accept all with error", // No setup needed, tables accept everything by default. setup: func(*testing.T, *stack.Stack) {}, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} }, + allowPackets: nPackets - 1, expectSent: nPackets - 1, expectDropped: 0, expectWritten: nPackets - 1, @@ -1086,10 +1042,10 @@ func TestWriteStats(t *testing.T) { ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = stack.DropTarget{} if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) + t.Fatalf("failed to replace table: %s", err) } }, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: 0, expectDropped: nPackets, expectWritten: nPackets, @@ -1111,10 +1067,10 @@ func TestWriteStats(t *testing.T) { // Make sure the next rule is ACCEPT. filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) + t.Fatalf("failed to replace table: %s", err) } }, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: nPackets - 1, expectDropped: 1, expectWritten: nPackets, @@ -1150,7 +1106,8 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - rt := buildRoute(t, nil, test.linkEP()) + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets) + rt := buildRoute(t, ep) var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { @@ -1181,101 +1138,37 @@ func TestWriteStats(t *testing.T) { } } -func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route { +func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, }) - s.CreateNIC(1, linkEP) + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC(1, _) failed: %s", err) + } const ( src = "\x10\x00\x00\x01" dst = "\x10\x00\x00\x02" ) - s.AddAddress(1, ipv4.ProtocolNumber, src) + if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err) + } { subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) if err != nil { - t.Fatal(err) + t.Fatalf("NewSubnet(_, _) failed: %v", err) } s.SetRouteTable([]tcpip.Route{{ Destination: subnet, NIC: 1, }}) } - rt, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) + rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) + t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err) } return rt } -// limitedEP is a link endpoint that writes up to a certain number of packets -// before returning errors. -type limitedEP struct { - limit int -} - -// MTU implements LinkEndpoint.MTU. -func (*limitedEP) MTU() uint32 { - // Give an MTU that won't cause fragmentation for IPv4+UDP. - return header.IPv4MinimumSize + header.UDPMinimumSize -} - -// Capabilities implements LinkEndpoint.Capabilities. -func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 } - -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*limitedEP) MaxHeaderLength() uint16 { return 0 } - -// LinkAddress implements LinkEndpoint.LinkAddress. -func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" } - -// WritePacket implements LinkEndpoint.WritePacket. -func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { - if ep.limit == 0 { - return tcpip.ErrInvalidEndpointState - } - ep.limit-- - return nil -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - if ep.limit == 0 { - return 0, tcpip.ErrInvalidEndpointState - } - nWritten := ep.limit - if nWritten > pkts.Len() { - nWritten = pkts.Len() - } - ep.limit -= nWritten - return nWritten, nil -} - -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { - if ep.limit == 0 { - return tcpip.ErrInvalidEndpointState - } - ep.limit-- - return nil -} - -// Attach implements LinkEndpoint.Attach. -func (*limitedEP) Attach(_ stack.NetworkDispatcher) {} - -// IsAttached implements LinkEndpoint.IsAttached. -func (*limitedEP) IsAttached() bool { return false } - -// Wait implements LinkEndpoint.Wait. -func (*limitedEP) Wait() {} - -// ARPHardwareType implements LinkEndpoint.ARPHardwareType. -func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther } - -// AddHeader implements LinkEndpoint.AddHeader. -func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { -} - // limitedMatcher is an iptables matcher that matches after a certain number of // packets are checked against it. type limitedMatcher struct { diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index cd5fe3ea8..8bd8f5c52 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -35,6 +35,7 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 9eea1de8d..7d138dadb 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -1715,7 +1716,7 @@ func TestWriteStats(t *testing.T) { tests := []struct { name string setup func(*testing.T, *stack.Stack) - linkEP func() stack.LinkEndpoint + allowPackets int expectSent int expectDropped int expectWritten int @@ -1724,7 +1725,7 @@ func TestWriteStats(t *testing.T) { name: "Accept all", // No setup needed, tables accept everything by default. setup: func(*testing.T, *stack.Stack) {}, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: nPackets, expectDropped: 0, expectWritten: nPackets, @@ -1732,7 +1733,7 @@ func TestWriteStats(t *testing.T) { name: "Accept all with error", // No setup needed, tables accept everything by default. setup: func(*testing.T, *stack.Stack) {}, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} }, + allowPackets: nPackets - 1, expectSent: nPackets - 1, expectDropped: 0, expectWritten: nPackets - 1, @@ -1752,7 +1753,7 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %v", err) } }, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: 0, expectDropped: nPackets, expectWritten: nPackets, @@ -1777,7 +1778,7 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %v", err) } }, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: nPackets - 1, expectDropped: 1, expectWritten: nPackets, @@ -1812,7 +1813,8 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - rt := buildRoute(t, nil, test.linkEP()) + ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) + rt := buildRoute(t, ep) var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { @@ -1843,100 +1845,37 @@ func TestWriteStats(t *testing.T) { } } -func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route { +func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, }) - s.CreateNIC(1, linkEP) + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC(1, _) failed: %s", err) + } const ( src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" ) - s.AddAddress(1, ProtocolNumber, src) + if err := s.AddAddress(1, ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err) + } { subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")) if err != nil { - t.Fatal(err) + t.Fatalf("NewSubnet(_, _) failed: %v", err) } s.SetRouteTable([]tcpip.Route{{ Destination: subnet, NIC: 1, }}) } - rt, err := s.FindRoute(0, src, dst, ProtocolNumber, false /* multicastLoop */) + rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) + t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err) } return rt } -// limitedEP is a link endpoint that writes up to a certain number of packets -// before returning errors. -type limitedEP struct { - limit int -} - -// MTU implements LinkEndpoint.MTU. -func (*limitedEP) MTU() uint32 { - return header.IPv6MinimumMTU -} - -// Capabilities implements LinkEndpoint.Capabilities. -func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 } - -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*limitedEP) MaxHeaderLength() uint16 { return 0 } - -// LinkAddress implements LinkEndpoint.LinkAddress. -func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" } - -// WritePacket implements LinkEndpoint.WritePacket. -func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { - if ep.limit == 0 { - return tcpip.ErrInvalidEndpointState - } - ep.limit-- - return nil -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - if ep.limit == 0 { - return 0, tcpip.ErrInvalidEndpointState - } - nWritten := ep.limit - if nWritten > pkts.Len() { - nWritten = pkts.Len() - } - ep.limit -= nWritten - return nWritten, nil -} - -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { - if ep.limit == 0 { - return tcpip.ErrInvalidEndpointState - } - ep.limit-- - return nil -} - -// Attach implements LinkEndpoint.Attach. -func (*limitedEP) Attach(_ stack.NetworkDispatcher) {} - -// IsAttached implements LinkEndpoint.IsAttached. -func (*limitedEP) IsAttached() bool { return false } - -// Wait implements LinkEndpoint.Wait. -func (*limitedEP) Wait() {} - -// ARPHardwareType implements LinkEndpoint.ARPHardwareType. -func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther } - -// AddHeader implements LinkEndpoint.AddHeader. -func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { -} - // limitedMatcher is an iptables matcher that matches after a certain number of // packets are checked against it. type limitedMatcher struct { diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD index e218563d0..c9e57dc0d 100644 --- a/pkg/tcpip/network/testutil/BUILD +++ b/pkg/tcpip/network/testutil/BUILD @@ -7,11 +7,14 @@ go_library( srcs = [ "testutil.go", ], - visibility = ["//pkg/tcpip/network/ipv4:__pkg__"], + visibility = [ + "//pkg/tcpip/network/ipv4:__pkg__", + "//pkg/tcpip/network/ipv6:__pkg__", + ], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", - "//pkg/tcpip/link/channel", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go index bf5ce74be..7cc52985e 100644 --- a/pkg/tcpip/network/testutil/testutil.go +++ b/pkg/tcpip/network/testutil/testutil.go @@ -22,48 +22,100 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// TestEndpoint is an endpoint used for testing, it stores packets written to it -// and can mock errors. -type TestEndpoint struct { - *channel.Endpoint - - // WrittenPackets is where we store packets written via WritePacket(). +// MockLinkEndpoint is an endpoint used for testing, it stores packets written +// to it and can mock errors. +type MockLinkEndpoint struct { + // WrittenPackets is where packets written to the endpoint are stored. WrittenPackets []*stack.PacketBuffer - packetCollectorErrors []*tcpip.Error + mtu uint32 + err *tcpip.Error + allowPackets int } -// NewTestEndpoint creates a new TestEndpoint endpoint. +// NewMockLinkEndpoint creates a new MockLinkEndpoint. // -// packetCollectorErrors can be used to set error values and each call to -// WritePacket will remove the first one from the slice and return it until -// the slice is empty - at that point it will return nil every time. -func NewTestEndpoint(ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) *TestEndpoint { - return &TestEndpoint{ - Endpoint: ep, - WrittenPackets: make([]*stack.PacketBuffer, 0), - packetCollectorErrors: packetCollectorErrors, +// err is the error that will be returned once allowPackets packets are written +// to the endpoint. +func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint { + return &MockLinkEndpoint{ + mtu: mtu, + err: err, + allowPackets: allowPackets, + } +} + +// MTU implements LinkEndpoint.MTU. +func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu } + +// Capabilities implements LinkEndpoint.Capabilities. +func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 } + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } + +// LinkAddress implements LinkEndpoint.LinkAddress. +func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } + +// WritePacket implements LinkEndpoint.WritePacket. +func (ep *MockLinkEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if ep.allowPackets == 0 { + return ep.err } + ep.allowPackets-- + ep.WrittenPackets = append(ep.WrittenPackets, pkt) + return nil } -// WritePacket stores outbound packets and may return an error if one was -// injected. -func (e *TestEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.WrittenPackets = append(e.WrittenPackets, pkt) +// WritePackets implements LinkEndpoint.WritePackets. +func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + var n int - if len(e.packetCollectorErrors) > 0 { - nextError := e.packetCollectorErrors[0] - e.packetCollectorErrors = e.packetCollectorErrors[1:] - return nextError + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := ep.WritePacket(r, gso, protocol, pkt); err != nil { + return n, err + } + n++ } + return n, nil +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + if ep.allowPackets == 0 { + return ep.err + } + ep.allowPackets-- + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + ep.WrittenPackets = append(ep.WrittenPackets, pkt) + return nil } +// Attach implements LinkEndpoint.Attach. +func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} + +// IsAttached implements LinkEndpoint.IsAttached. +func (*MockLinkEndpoint) IsAttached() bool { return false } + +// Wait implements LinkEndpoint.Wait. +func (*MockLinkEndpoint) Wait() {} + +// ARPHardwareType implements LinkEndpoint.ARPHardwareType. +func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } + +// AddHeader implements LinkEndpoint.AddHeader. +func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} + // MakeRandPkt generates a randomized packet. transportHeaderLength indicates // how many random bytes will be copied in the Transport Header. // extraHeaderReserveLength indicates how much extra space will be reserved for |