diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-02-05 18:41:37 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-02-05 18:44:04 -0800 |
commit | 83b764d9d2193e2e01f3a60792f3468c1843c5a8 (patch) | |
tree | 3cb303660a15cfd0b2150ee3d93966636dbb3054 /pkg/tcpip/tests | |
parent | 120c8e34687129c919ae45263c14b239a0a5d343 (diff) |
Batch write packets after iptables checks
After IPTables checks a batch of packets, we can write packets that are
not dropped or locally destined as a batch instead of individually.
This previously caused a bug since WritePacket* functions expect to take
ownership of passed PacketBuffer{List}. WritePackets assumed the list of
PacketBuffers will not be invalidated when calling WritePacket for each
PacketBuffer in the list, but this is not true. WritePacket may add the
passed PacketBuffer into a different list which would modify the
PacketBuffer in such a way that it no longer points to the next
PacketBuffer to write.
Example: Given a PB list of
PB_a -> PB_b -> PB_c
WritePackets may be iterating over the list and calling WritePacket for
each PB. When WritePacket takes PB_a, it may add it to a new list which
would update pointers such that PB_a no longer points to PB_b.
Test: integration_test.TestIPTableWritePackets
PiperOrigin-RevId: 355969560
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r-- | pkg/tcpip/tests/integration/iptables_test.go | 310 |
1 files changed, 310 insertions, 0 deletions
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 21a8dd291..b56706357 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) type inputIfNameMatcher struct { @@ -334,3 +335,312 @@ func TestIPTablesStatsForInput(t *testing.T) { }) } } + +var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil) + +// channelEndpointWithoutWritePacket is a channel endpoint that does not support +// stack.LinkEndpoint.WritePacket. +type channelEndpointWithoutWritePacket struct { + *channel.Endpoint + + t *testing.T +} + +func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { + c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets") + return &tcpip.ErrNotSupported{} +} + +var _ stack.Matcher = (*udpSourcePortMatcher)(nil) + +type udpSourcePortMatcher struct { + port uint16 +} + +func (*udpSourcePortMatcher) Name() string { + return "udpSourcePortMatcher" +} + +func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) { + udp := header.UDP(pkt.TransportHeader().View()) + if len(udp) < header.UDPMinimumSize { + // Drop immediately as the packet is invalid. + return false, true + } + + return udp.SourcePort() == m.port, false +} + +func TestIPTableWritePackets(t *testing.T) { + const ( + nicID = 1 + + dropLocalPort = localPort - 1 + acceptPackets = 2 + dropPackets = 3 + ) + + udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) { + u := header.UDP(hdr) + u.Encode(&header.UDPFields{ + SrcPort: srcPort, + DstPort: dstPort, + Length: header.UDPMinimumSize, + }) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize) + sum = header.Checksum(hdr, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + } + + tests := []struct { + name string + setupFilter func(*testing.T, *stack.Stack) + genPacket func(*stack.Route) stack.PacketBufferList + proto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectSent uint64 + expectOutputDropped uint64 + }{ + { + name: "IPv4 Accept", + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + + return pkts + }, + proto: header.IPv4ProtocolNumber, + remoteAddr: dstAddrV4, + expectSent: 1, + expectOutputDropped: 0, + }, + { + name: "IPv4 Drop Other Port", + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + + table := stack.Table{ + Rules: []stack.Rule{ + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, + Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + Underflows: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + } + + if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil { + t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err) + } + }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + for i := 0; i < acceptPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + } + for i := 0; i < dropPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort) + pkts.PushFront(pkt) + } + + return pkts + }, + proto: header.IPv4ProtocolNumber, + remoteAddr: dstAddrV4, + expectSent: acceptPackets, + expectOutputDropped: dropPackets, + }, + { + name: "IPv6 Accept", + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + + return pkts + }, + proto: header.IPv6ProtocolNumber, + remoteAddr: dstAddrV6, + expectSent: 1, + expectOutputDropped: 0, + }, + { + name: "IPv6 Drop Other Port", + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + + table := stack.Table{ + Rules: []stack.Rule{ + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, + Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + Underflows: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + } + + if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil { + t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err) + } + }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + for i := 0; i < acceptPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + } + for i := 0; i < dropPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort) + pkts.PushFront(pkt) + } + + return pkts + }, + proto: header.IPv6ProtocolNumber, + remoteAddr: dstAddrV6, + expectSent: acceptPackets, + expectOutputDropped: dropPackets, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + e := channelEndpointWithoutWritePacket{ + Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr), + t: t, + } + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err) + } + if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + + test.setupFilter(t, s) + + r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false) + if err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err) + } + defer r.Release() + + pkts := test.genPacket(r) + pktsLen := pkts.Len() + if n, err := r.WritePackets(nil /* gso */, pkts, stack.NetworkHeaderParams{ + Protocol: header.UDPProtocolNumber, + TTL: 64, + }); err != nil { + t.Fatalf("WritePackets(...): %s", err) + } else if n != pktsLen { + t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen) + } + + if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent { + t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent) + } + if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped { + t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped) + } + }) + } +} |