diff options
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 187 |
1 files changed, 40 insertions, 147 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index b14bc98e8..86187aba8 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -17,6 +17,7 @@ package ipv4_test import ( "bytes" "encoding/hex" + "math" "testing" "github.com/google/go-cmp/cmp" @@ -160,47 +161,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI } } -type testRoute struct { - stack.Route - - linkEP *testutil.TestEndpoint -} - -func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute { - // Make the packet and write it. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - }) - testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors) - s.CreateNIC(1, testEP) - const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" - ) - s.AddAddress(1, ipv4.ProtocolNumber, src) - { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) - } - t.Cleanup(func() { - testEP.Close() - }) - return testRoute{ - Route: r, - linkEP: testEP, - } -} - func TestFragmentation(t *testing.T) { var manyPayloadViewsSizes [1000]int for i := range manyPayloadViewsSizes { @@ -228,7 +188,8 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil) + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber) source := pkt.Clone() err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ @@ -237,16 +198,16 @@ func TestFragmentation(t *testing.T) { TOS: stack.DefaultTOS, }, pkt) if err != nil { - t.Errorf("err got %v, want %v", err, nil) + t.Errorf("got err = %s, want = nil", err) } - if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want { - t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want) + if got := len(ep.WrittenPackets); got != ft.expectedFrags { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags) } - if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want) + if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { + t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want) } - compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu) + compareFragments(t, ep.WrittenPackets, source, ft.mtu) }) } } @@ -259,35 +220,30 @@ func TestFragmentationErrors(t *testing.T) { mtu uint32 transportHeaderLength int payloadViewsSizes []int - packetCollectorErrors []*tcpip.Error + err *tcpip.Error + allowPackets int }{ - {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, - {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, + {"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0}, + {"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0}, + {"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1}, + {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0}, } for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors) + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets) + r := buildRoute(t, ep) pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, }, pkt) - for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { - if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { - t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) - } - } - // We only need to check that last error because all the ones before are - // nil. - if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { - t.Errorf("err got %v, want %v", got, want) + if err != ft.err { + t.Errorf("got WritePacket() = %s, want = %s", err, ft.err) } - if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { - t.Errorf("after linkEP error len(result) got %d, want %d", got, want) + if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) } }) } @@ -1052,7 +1008,7 @@ func TestWriteStats(t *testing.T) { tests := []struct { name string setup func(*testing.T, *stack.Stack) - linkEP func() stack.LinkEndpoint + allowPackets int expectSent int expectDropped int expectWritten int @@ -1061,7 +1017,7 @@ func TestWriteStats(t *testing.T) { name: "Accept all", // No setup needed, tables accept everything by default. setup: func(*testing.T, *stack.Stack) {}, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: nPackets, expectDropped: 0, expectWritten: nPackets, @@ -1069,7 +1025,7 @@ func TestWriteStats(t *testing.T) { name: "Accept all with error", // No setup needed, tables accept everything by default. setup: func(*testing.T, *stack.Stack) {}, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} }, + allowPackets: nPackets - 1, expectSent: nPackets - 1, expectDropped: 0, expectWritten: nPackets - 1, @@ -1086,10 +1042,10 @@ func TestWriteStats(t *testing.T) { ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = stack.DropTarget{} if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) + t.Fatalf("failed to replace table: %s", err) } }, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: 0, expectDropped: nPackets, expectWritten: nPackets, @@ -1111,10 +1067,10 @@ func TestWriteStats(t *testing.T) { // Make sure the next rule is ACCEPT. filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) + t.Fatalf("failed to replace table: %s", err) } }, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: nPackets - 1, expectDropped: 1, expectWritten: nPackets, @@ -1150,7 +1106,8 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - rt := buildRoute(t, nil, test.linkEP()) + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets) + rt := buildRoute(t, ep) var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { @@ -1181,101 +1138,37 @@ func TestWriteStats(t *testing.T) { } } -func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route { +func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, }) - s.CreateNIC(1, linkEP) + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC(1, _) failed: %s", err) + } const ( src = "\x10\x00\x00\x01" dst = "\x10\x00\x00\x02" ) - s.AddAddress(1, ipv4.ProtocolNumber, src) + if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err) + } { subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) if err != nil { - t.Fatal(err) + t.Fatalf("NewSubnet(_, _) failed: %v", err) } s.SetRouteTable([]tcpip.Route{{ Destination: subnet, NIC: 1, }}) } - rt, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) + rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) + t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err) } return rt } -// limitedEP is a link endpoint that writes up to a certain number of packets -// before returning errors. -type limitedEP struct { - limit int -} - -// MTU implements LinkEndpoint.MTU. -func (*limitedEP) MTU() uint32 { - // Give an MTU that won't cause fragmentation for IPv4+UDP. - return header.IPv4MinimumSize + header.UDPMinimumSize -} - -// Capabilities implements LinkEndpoint.Capabilities. -func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 } - -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*limitedEP) MaxHeaderLength() uint16 { return 0 } - -// LinkAddress implements LinkEndpoint.LinkAddress. -func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" } - -// WritePacket implements LinkEndpoint.WritePacket. -func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { - if ep.limit == 0 { - return tcpip.ErrInvalidEndpointState - } - ep.limit-- - return nil -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - if ep.limit == 0 { - return 0, tcpip.ErrInvalidEndpointState - } - nWritten := ep.limit - if nWritten > pkts.Len() { - nWritten = pkts.Len() - } - ep.limit -= nWritten - return nWritten, nil -} - -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { - if ep.limit == 0 { - return tcpip.ErrInvalidEndpointState - } - ep.limit-- - return nil -} - -// Attach implements LinkEndpoint.Attach. -func (*limitedEP) Attach(_ stack.NetworkDispatcher) {} - -// IsAttached implements LinkEndpoint.IsAttached. -func (*limitedEP) IsAttached() bool { return false } - -// Wait implements LinkEndpoint.Wait. -func (*limitedEP) Wait() {} - -// ARPHardwareType implements LinkEndpoint.ARPHardwareType. -func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther } - -// AddHeader implements LinkEndpoint.AddHeader. -func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { -} - // limitedMatcher is an iptables matcher that matches after a certain number of // packets are checked against it. type limitedMatcher struct { |