diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 208 |
1 files changed, 208 insertions, 0 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 5e50558e8..c1a560914 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -1046,3 +1046,211 @@ func TestReceiveFragments(t *testing.T) { }) } } + +func TestWritePacketsStats(t *testing.T) { + const nPackets = 3 + tests := []struct { + name string + setup func(*testing.T, *stack.Stack) + linkEP stack.LinkEndpoint + expectSent int + }{ + { + name: "Accept all", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + linkEP: &limitedEP{nPackets}, + expectSent: nPackets, + }, { + name: "Accept all with error", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + linkEP: &limitedEP{nPackets - 1}, + expectSent: nPackets - 1, + }, { + name: "Drop all", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = stack.DropTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + linkEP: &limitedEP{nPackets}, + expectSent: 0, + }, { + name: "Drop some", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule that matches only 1 + // of the 3 packets. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. + filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + linkEP: &limitedEP{nPackets}, + expectSent: nPackets - 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + rt := buildRoute(t, nil, test.linkEP) + + var pbl stack.PacketBufferList + for i := 0; i < nPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), + Data: buffer.NewView(1).ToVectorisedView(), + }) + pkt.TransportHeader().Push(header.UDPMinimumSize) + pbl.PushBack(pkt) + } + + test.setup(t, rt.Stack()) + + nWritten, err := rt.WritePackets(nil, pbl, stack.NetworkHeaderParams{}) + if err != nil { + t.Fatal(err) + } + + got := int(rt.Stats().IP.PacketsSent.Value()) + if got != test.expectSent { + t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + } + if got != nWritten { + t.Errorf("sent %d packets, WritePackets returned %d", got, nWritten) + } + }) + } +} + +func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + s.CreateNIC(1, linkEP) + const ( + src = "\x10\x00\x00\x01" + dst = "\x10\x00\x00\x02" + ) + s.AddAddress(1, ipv4.ProtocolNumber, src) + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + rt, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute got %v, want %v", err, nil) + } + return rt +} + +// limitedEP is a link endpoint that writes up to a certain number of packets +// before returning errors. +type limitedEP struct { + limit int +} + +// MTU implements LinkEndpoint.MTU. +func (*limitedEP) MTU() uint32 { return 0 } + +// Capabilities implements LinkEndpoint.Capabilities. +func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 } + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*limitedEP) MaxHeaderLength() uint16 { return 0 } + +// LinkAddress implements LinkEndpoint.LinkAddress. +func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" } + +// WritePacket implements LinkEndpoint.WritePacket. +func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + if ep.limit == 0 { + return tcpip.ErrInvalidEndpointState + } + ep.limit-- + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if ep.limit == 0 { + return 0, tcpip.ErrInvalidEndpointState + } + nWritten := ep.limit + if nWritten > pkts.Len() { + nWritten = pkts.Len() + } + ep.limit -= nWritten + return nWritten, nil +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { + if ep.limit == 0 { + return tcpip.ErrInvalidEndpointState + } + ep.limit-- + return nil +} + +// Attach implements LinkEndpoint.Attach. +func (*limitedEP) Attach(_ stack.NetworkDispatcher) {} + +// IsAttached implements LinkEndpoint.IsAttached. +func (*limitedEP) IsAttached() bool { return false } + +// Wait implements LinkEndpoint.Wait. +func (*limitedEP) Wait() {} + +// ARPHardwareType implements LinkEndpoint.ARPHardwareType. +func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther } + +// AddHeader implements LinkEndpoint.AddHeader. +func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} + +// limitedMatcher is an iptables matcher that matches after a certain number of +// packets are checked against it. +type limitedMatcher struct { + limit int +} + +// Name implements Matcher.Name. +func (*limitedMatcher) Name() string { + return "limitedMatcher" +} + +// Match implements Matcher.Match. +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { + if lm.limit == 0 { + return true, false + } + lm.limit-- + return false, false +} |