diff options
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r-- | pkg/tcpip/tests/integration/iptables_test.go | 310 |
1 files changed, 310 insertions, 0 deletions
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 21a8dd291..b56706357 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) type inputIfNameMatcher struct { @@ -334,3 +335,312 @@ func TestIPTablesStatsForInput(t *testing.T) { }) } } + +var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil) + +// channelEndpointWithoutWritePacket is a channel endpoint that does not support +// stack.LinkEndpoint.WritePacket. +type channelEndpointWithoutWritePacket struct { + *channel.Endpoint + + t *testing.T +} + +func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { + c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets") + return &tcpip.ErrNotSupported{} +} + +var _ stack.Matcher = (*udpSourcePortMatcher)(nil) + +type udpSourcePortMatcher struct { + port uint16 +} + +func (*udpSourcePortMatcher) Name() string { + return "udpSourcePortMatcher" +} + +func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) { + udp := header.UDP(pkt.TransportHeader().View()) + if len(udp) < header.UDPMinimumSize { + // Drop immediately as the packet is invalid. + return false, true + } + + return udp.SourcePort() == m.port, false +} + +func TestIPTableWritePackets(t *testing.T) { + const ( + nicID = 1 + + dropLocalPort = localPort - 1 + acceptPackets = 2 + dropPackets = 3 + ) + + udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) { + u := header.UDP(hdr) + u.Encode(&header.UDPFields{ + SrcPort: srcPort, + DstPort: dstPort, + Length: header.UDPMinimumSize, + }) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize) + sum = header.Checksum(hdr, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + } + + tests := []struct { + name string + setupFilter func(*testing.T, *stack.Stack) + genPacket func(*stack.Route) stack.PacketBufferList + proto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectSent uint64 + expectOutputDropped uint64 + }{ + { + name: "IPv4 Accept", + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + + return pkts + }, + proto: header.IPv4ProtocolNumber, + remoteAddr: dstAddrV4, + expectSent: 1, + expectOutputDropped: 0, + }, + { + name: "IPv4 Drop Other Port", + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + + table := stack.Table{ + Rules: []stack.Rule{ + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, + Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + Underflows: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + } + + if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil { + t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err) + } + }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + for i := 0; i < acceptPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + } + for i := 0; i < dropPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort) + pkts.PushFront(pkt) + } + + return pkts + }, + proto: header.IPv4ProtocolNumber, + remoteAddr: dstAddrV4, + expectSent: acceptPackets, + expectOutputDropped: dropPackets, + }, + { + name: "IPv6 Accept", + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + + return pkts + }, + proto: header.IPv6ProtocolNumber, + remoteAddr: dstAddrV6, + expectSent: 1, + expectOutputDropped: 0, + }, + { + name: "IPv6 Drop Other Port", + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + + table := stack.Table{ + Rules: []stack.Rule{ + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, + Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + Underflows: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + } + + if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil { + t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err) + } + }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + for i := 0; i < acceptPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + } + for i := 0; i < dropPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort) + pkts.PushFront(pkt) + } + + return pkts + }, + proto: header.IPv6ProtocolNumber, + remoteAddr: dstAddrV6, + expectSent: acceptPackets, + expectOutputDropped: dropPackets, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + e := channelEndpointWithoutWritePacket{ + Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr), + t: t, + } + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err) + } + if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + + test.setupFilter(t, s) + + r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false) + if err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err) + } + defer r.Release() + + pkts := test.genPacket(r) + pktsLen := pkts.Len() + if n, err := r.WritePackets(nil /* gso */, pkts, stack.NetworkHeaderParams{ + Protocol: header.UDPProtocolNumber, + TTL: 64, + }); err != nil { + t.Fatalf("WritePackets(...): %s", err) + } else if n != pktsLen { + t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen) + } + + if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent { + t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent) + } + if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped { + t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped) + } + }) + } +} |