summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/ipv4
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go187
1 files changed, 40 insertions, 147 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index b14bc98e8..86187aba8 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -17,6 +17,7 @@ package ipv4_test
import (
"bytes"
"encoding/hex"
+ "math"
"testing"
"github.com/google/go-cmp/cmp"
@@ -160,47 +161,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
}
-type testRoute struct {
- stack.Route
-
- linkEP *testutil.TestEndpoint
-}
-
-func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute {
- // Make the packet and write it.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- })
- testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors)
- s.CreateNIC(1, testEP)
- const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
- )
- s.AddAddress(1, ipv4.ProtocolNumber, src)
- {
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }})
- }
- r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("s.FindRoute got %v, want %v", err, nil)
- }
- t.Cleanup(func() {
- testEP.Close()
- })
- return testRoute{
- Route: r,
- linkEP: testEP,
- }
-}
-
func TestFragmentation(t *testing.T) {
var manyPayloadViewsSizes [1000]int
for i := range manyPayloadViewsSizes {
@@ -228,7 +188,8 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
source := pkt.Clone()
err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
@@ -237,16 +198,16 @@ func TestFragmentation(t *testing.T) {
TOS: stack.DefaultTOS,
}, pkt)
if err != nil {
- t.Errorf("err got %v, want %v", err, nil)
+ t.Errorf("got err = %s, want = nil", err)
}
- if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want {
- t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
+ if got := len(ep.WrittenPackets); got != ft.expectedFrags {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags)
}
- if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
- t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
+ if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
+ t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
- compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu)
+ compareFragments(t, ep.WrittenPackets, source, ft.mtu)
})
}
}
@@ -259,35 +220,30 @@ func TestFragmentationErrors(t *testing.T) {
mtu uint32
transportHeaderLength int
payloadViewsSizes []int
- packetCollectorErrors []*tcpip.Error
+ err *tcpip.Error
+ allowPackets int
}{
- {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
- {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0},
+ {"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0},
+ {"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1},
+ {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0},
}
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets)
+ r := buildRoute(t, ep)
pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
}, pkt)
- for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
- if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
- t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
- }
- }
- // We only need to check that last error because all the ones before are
- // nil.
- if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
- t.Errorf("err got %v, want %v", got, want)
+ if err != ft.err {
+ t.Errorf("got WritePacket() = %s, want = %s", err, ft.err)
}
- if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
- t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
+ if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
})
}
@@ -1052,7 +1008,7 @@ func TestWriteStats(t *testing.T) {
tests := []struct {
name string
setup func(*testing.T, *stack.Stack)
- linkEP func() stack.LinkEndpoint
+ allowPackets int
expectSent int
expectDropped int
expectWritten int
@@ -1061,7 +1017,7 @@ func TestWriteStats(t *testing.T) {
name: "Accept all",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: nPackets,
expectDropped: 0,
expectWritten: nPackets,
@@ -1069,7 +1025,7 @@ func TestWriteStats(t *testing.T) {
name: "Accept all with error",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} },
+ allowPackets: nPackets - 1,
expectSent: nPackets - 1,
expectDropped: 0,
expectWritten: nPackets - 1,
@@ -1086,10 +1042,10 @@ func TestWriteStats(t *testing.T) {
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = stack.DropTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
+ t.Fatalf("failed to replace table: %s", err)
}
},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: 0,
expectDropped: nPackets,
expectWritten: nPackets,
@@ -1111,10 +1067,10 @@ func TestWriteStats(t *testing.T) {
// Make sure the next rule is ACCEPT.
filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
+ t.Fatalf("failed to replace table: %s", err)
}
},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: nPackets - 1,
expectDropped: 1,
expectWritten: nPackets,
@@ -1150,7 +1106,8 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- rt := buildRoute(t, nil, test.linkEP())
+ ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
@@ -1181,101 +1138,37 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
})
- s.CreateNIC(1, linkEP)
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC(1, _) failed: %s", err)
+ }
const (
src = "\x10\x00\x00\x01"
dst = "\x10\x00\x00\x02"
)
- s.AddAddress(1, ipv4.ProtocolNumber, src)
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err)
+ }
{
subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
if err != nil {
- t.Fatal(err)
+ t.Fatalf("NewSubnet(_, _) failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
NIC: 1,
}})
}
- rt, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
+ rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err)
}
return rt
}
-// limitedEP is a link endpoint that writes up to a certain number of packets
-// before returning errors.
-type limitedEP struct {
- limit int
-}
-
-// MTU implements LinkEndpoint.MTU.
-func (*limitedEP) MTU() uint32 {
- // Give an MTU that won't cause fragmentation for IPv4+UDP.
- return header.IPv4MinimumSize + header.UDPMinimumSize
-}
-
-// Capabilities implements LinkEndpoint.Capabilities.
-func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 }
-
-// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
-func (*limitedEP) MaxHeaderLength() uint16 { return 0 }
-
-// LinkAddress implements LinkEndpoint.LinkAddress.
-func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" }
-
-// WritePacket implements LinkEndpoint.WritePacket.
-func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
- if ep.limit == 0 {
- return tcpip.ErrInvalidEndpointState
- }
- ep.limit--
- return nil
-}
-
-// WritePackets implements LinkEndpoint.WritePackets.
-func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- if ep.limit == 0 {
- return 0, tcpip.ErrInvalidEndpointState
- }
- nWritten := ep.limit
- if nWritten > pkts.Len() {
- nWritten = pkts.Len()
- }
- ep.limit -= nWritten
- return nWritten, nil
-}
-
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
- if ep.limit == 0 {
- return tcpip.ErrInvalidEndpointState
- }
- ep.limit--
- return nil
-}
-
-// Attach implements LinkEndpoint.Attach.
-func (*limitedEP) Attach(_ stack.NetworkDispatcher) {}
-
-// IsAttached implements LinkEndpoint.IsAttached.
-func (*limitedEP) IsAttached() bool { return false }
-
-// Wait implements LinkEndpoint.Wait.
-func (*limitedEP) Wait() {}
-
-// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
-func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }
-
-// AddHeader implements LinkEndpoint.AddHeader.
-func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
-}
-
// limitedMatcher is an iptables matcher that matches after a certain number of
// packets are checked against it.
type limitedMatcher struct {