summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/tests
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go310
1 files changed, 310 insertions, 0 deletions
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 21a8dd291..b56706357 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
type inputIfNameMatcher struct {
@@ -334,3 +335,312 @@ func TestIPTablesStatsForInput(t *testing.T) {
})
}
}
+
+var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil)
+
+// channelEndpointWithoutWritePacket is a channel endpoint that does not support
+// stack.LinkEndpoint.WritePacket.
+type channelEndpointWithoutWritePacket struct {
+ *channel.Endpoint
+
+ t *testing.T
+}
+
+func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
+ c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets")
+ return &tcpip.ErrNotSupported{}
+}
+
+var _ stack.Matcher = (*udpSourcePortMatcher)(nil)
+
+type udpSourcePortMatcher struct {
+ port uint16
+}
+
+func (*udpSourcePortMatcher) Name() string {
+ return "udpSourcePortMatcher"
+}
+
+func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) {
+ udp := header.UDP(pkt.TransportHeader().View())
+ if len(udp) < header.UDPMinimumSize {
+ // Drop immediately as the packet is invalid.
+ return false, true
+ }
+
+ return udp.SourcePort() == m.port, false
+}
+
+func TestIPTableWritePackets(t *testing.T) {
+ const (
+ nicID = 1
+
+ dropLocalPort = localPort - 1
+ acceptPackets = 2
+ dropPackets = 3
+ )
+
+ udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) {
+ u := header.UDP(hdr)
+ u.Encode(&header.UDPFields{
+ SrcPort: srcPort,
+ DstPort: dstPort,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize)
+ sum = header.Checksum(hdr, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ }
+
+ tests := []struct {
+ name string
+ setupFilter func(*testing.T, *stack.Stack)
+ genPacket func(*stack.Route) stack.PacketBufferList
+ proto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectSent uint64
+ expectOutputDropped uint64
+ }{
+ {
+ name: "IPv4 Accept",
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+
+ return pkts
+ },
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: dstAddrV4,
+ expectSent: 1,
+ expectOutputDropped: 0,
+ },
+ {
+ name: "IPv4 Drop Other Port",
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
+ Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ Underflows: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ }
+
+ if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil {
+ t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err)
+ }
+ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ for i := 0; i < acceptPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+ for i := 0; i < dropPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+
+ return pkts
+ },
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: dstAddrV4,
+ expectSent: acceptPackets,
+ expectOutputDropped: dropPackets,
+ },
+ {
+ name: "IPv6 Accept",
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+
+ return pkts
+ },
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: dstAddrV6,
+ expectSent: 1,
+ expectOutputDropped: 0,
+ },
+ {
+ name: "IPv6 Drop Other Port",
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
+ Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ Underflows: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ }
+
+ if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil {
+ t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err)
+ }
+ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ for i := 0; i < acceptPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+ for i := 0; i < dropPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+
+ return pkts
+ },
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: dstAddrV6,
+ expectSent: acceptPackets,
+ expectOutputDropped: dropPackets,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ e := channelEndpointWithoutWritePacket{
+ Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr),
+ t: t,
+ }
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ test.setupFilter(t, s)
+
+ r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err)
+ }
+ defer r.Release()
+
+ pkts := test.genPacket(r)
+ pktsLen := pkts.Len()
+ if n, err := r.WritePackets(nil /* gso */, pkts, stack.NetworkHeaderParams{
+ Protocol: header.UDPProtocolNumber,
+ TTL: 64,
+ }); err != nil {
+ t.Fatalf("WritePackets(...): %s", err)
+ } else if n != pktsLen {
+ t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen)
+ }
+
+ if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent {
+ t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent)
+ }
+ if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped {
+ t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped)
+ }
+ })
+ }
+}