diff options
-rw-r--r-- | pkg/sentry/socket/netfilter/ipv4.go | 30 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/ipv6.go | 31 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/netfilter.go | 9 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/owner_matcher.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/tcp_matcher.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/udp_matcher.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_types.go | 64 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/iptables_test.go | 336 | ||||
-rw-r--r-- | test/iptables/filter_input.go | 198 | ||||
-rw-r--r-- | test/iptables/iptables_test.go | 24 | ||||
-rw-r--r-- | test/iptables/iptables_util.go | 2 |
17 files changed, 664 insertions, 93 deletions
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index 70c561cce..2f913787b 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -15,7 +15,6 @@ package netfilter import ( - "bytes" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" @@ -220,18 +219,6 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) } - n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) - if n == -1 { - n = len(iptip.OutputInterface) - } - ifname := string(iptip.OutputInterface[:n]) - - n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) - if n == -1 { - n = len(iptip.OutputInterfaceMask) - } - ifnameMask := string(iptip.OutputInterfaceMask[:n]) - return stack.IPHeaderFilter{ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), // A Protocol value of 0 indicates all protocols match. @@ -242,8 +229,11 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { Src: tcpip.Address(iptip.Src[:]), SrcMask: tcpip.Address(iptip.SrcMask[:]), SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, - OutputInterface: ifname, - OutputInterfaceMask: ifnameMask, + InputInterface: string(trimNullBytes(iptip.InputInterface[:])), + InputInterfaceMask: string(trimNullBytes(iptip.InputInterfaceMask[:])), + InputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_IN != 0, + OutputInterface: string(trimNullBytes(iptip.OutputInterface[:])), + OutputInterfaceMask: string(trimNullBytes(iptip.OutputInterfaceMask[:])), OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, }, nil } @@ -254,12 +244,12 @@ func containsUnsupportedFields4(iptip linux.IPTIP) bool { // - Dst and DstMask // - Src and SrcMask // - The inverse destination IP check flag + // - InputInterface, InputInterfaceMask and its inverse. // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInterface = [linux.IFNAMSIZ]byte{} + const flagMask = 0 // Disable any supported inverse flags. - inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) - return iptip.InputInterface != emptyInterface || - iptip.InputInterfaceMask != emptyInterface || - iptip.Flags != 0 || + const inverseMask = linux.IPT_INV_DSTIP | linux.IPT_INV_SRCIP | + linux.IPT_INV_VIA_IN | linux.IPT_INV_VIA_OUT + return iptip.Flags&^flagMask != 0 || iptip.InverseFlags&^inverseMask != 0 } diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 5dbb604f0..263d9d3b5 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -15,7 +15,6 @@ package netfilter import ( - "bytes" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" @@ -223,18 +222,6 @@ func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) } - n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) - if n == -1 { - n = len(iptip.OutputInterface) - } - ifname := string(iptip.OutputInterface[:n]) - - n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) - if n == -1 { - n = len(iptip.OutputInterfaceMask) - } - ifnameMask := string(iptip.OutputInterfaceMask[:n]) - return stack.IPHeaderFilter{ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), // In ip6tables a flag controls whether to check the protocol. @@ -245,8 +232,11 @@ func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) { Src: tcpip.Address(iptip.Src[:]), SrcMask: tcpip.Address(iptip.SrcMask[:]), SrcInvert: iptip.InverseFlags&linux.IP6T_INV_SRCIP != 0, - OutputInterface: ifname, - OutputInterfaceMask: ifnameMask, + InputInterface: string(trimNullBytes(iptip.InputInterface[:])), + InputInterfaceMask: string(trimNullBytes(iptip.InputInterfaceMask[:])), + InputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_IN != 0, + OutputInterface: string(trimNullBytes(iptip.OutputInterface[:])), + OutputInterfaceMask: string(trimNullBytes(iptip.OutputInterfaceMask[:])), OutputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_OUT != 0, }, nil } @@ -257,14 +247,13 @@ func containsUnsupportedFields6(iptip linux.IP6TIP) bool { // - Dst and DstMask // - Src and SrcMask // - The inverse destination IP check flag + // - InputInterface, InputInterfaceMask and its inverse. // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInterface = [linux.IFNAMSIZ]byte{} - flagMask := uint8(linux.IP6T_F_PROTO) + const flagMask = linux.IP6T_F_PROTO // Disable any supported inverse flags. - inverseMask := uint8(linux.IP6T_INV_DSTIP) | uint8(linux.IP6T_INV_SRCIP) | uint8(linux.IP6T_INV_VIA_OUT) - return iptip.InputInterface != emptyInterface || - iptip.InputInterfaceMask != emptyInterface || - iptip.Flags&^flagMask != 0 || + const inverseMask = linux.IP6T_INV_DSTIP | linux.IP6T_INV_SRCIP | + linux.IP6T_INV_VIA_IN | linux.IP6T_INV_VIA_OUT + return iptip.Flags&^flagMask != 0 || iptip.InverseFlags&^inverseMask != 0 || iptip.TOS != 0 } diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 26bd1abd4..7ae18b2a3 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -17,6 +17,7 @@ package netfilter import ( + "bytes" "errors" "fmt" @@ -393,3 +394,11 @@ func TargetRevision(t *kernel.Task, revPtr usermem.Addr, netProto tcpip.NetworkP rev.Revision = maxSupported return rev, nil } + +func trimNullBytes(b []byte) []byte { + n := bytes.IndexByte(b, 0) + if n == -1 { + n = len(b) + } + return b[:n] +} diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 69d13745e..176fa6116 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -112,7 +112,7 @@ func (*OwnerMatcher) Name() string { } // Match implements Matcher.Match. -func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { +func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // Support only for OUTPUT chain. // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. if hook != stack.Output { diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 352c51390..2740697b3 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -96,7 +96,7 @@ func (*TCPMatcher) Name() string { } // Match implements Matcher.Match. -func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { +func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. switch pkt.NetworkProtocolNumber { diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index c88d8268d..466d5395d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -93,7 +93,7 @@ func (*UDPMatcher) Name() string { } // Match implements Matcher.Match. -func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { +func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. switch pkt.NetworkProtocolNumber { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index f295a9192..2ccfa0822 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -321,8 +321,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -436,10 +436,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } } - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -611,7 +611,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. if !e.nic.IsLoopback() { - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -698,7 +699,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index af6f22598..dac7cbfd4 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -2656,7 +2656,7 @@ func (*limitedMatcher) Name() string { } // Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { if lm.limit == 0 { return true, false } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 37884505e..40176594e 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -631,8 +631,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -747,8 +747,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -897,7 +897,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. if !e.nic.IsLoopback() { - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -955,7 +956,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index aa892d043..7f48962d2 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -2575,7 +2575,7 @@ func (*limitedMatcher) Name() string { } // Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { if lm.limit == 0 { return true, false } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 09c7811fa..04af933a6 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -267,11 +267,11 @@ const ( // dropped. // // TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from -// which address and nicName can be gathered. Currently, address is only -// needed for prerouting and nicName is only needed for output. +// which address can be gathered. Currently, address is only needed for +// prerouting. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { return true } @@ -302,7 +302,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -385,10 +385,10 @@ func (it *IPTables) startReaper(interval time.Duration) { // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok { + if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -408,11 +408,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -429,7 +429,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -455,11 +455,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(pkt, hook, nicName) { + if !rule.Filter.match(pkt, hook, inNicName, outNicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -467,7 +467,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, pkt, "") + matches, hotdrop := matcher.Match(hook, pkt, inNicName, outNicName) if hotdrop { return RuleDrop, 0 } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 56a3e7861..fd9d61e39 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -210,8 +210,19 @@ type IPHeaderFilter struct { // filter will match packets that fail the source comparison. SrcInvert bool - // OutputInterface matches the name of the outgoing interface for the - // packet. + // InputInterface matches the name of the incoming interface for the packet. + InputInterface string + + // InputInterfaceMask masks the characters of the interface name when + // comparing with InputInterface. + InputInterfaceMask string + + // InputInterfaceInvert inverts the meaning of incoming interface check, + // i.e. when true the filter will match packets that fail the incoming + // interface comparison. + InputInterfaceInvert bool + + // OutputInterface matches the name of the outgoing interface for the packet. OutputInterface string // OutputInterfaceMask masks the characters of the interface name when @@ -228,7 +239,7 @@ type IPHeaderFilter struct { // // Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4 // or IPv6 header length. -func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool { +func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool { // Extract header fields. var ( // TODO(gvisor.dev/issue/170): Support other filter fields. @@ -264,26 +275,35 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo return false } - // Check the output interface. - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. - if hook == Output { - n := len(fl.OutputInterface) - if n == 0 { - return true - } - - // If the interface name ends with '+', any interface which - // begins with the name should be matched. - ifName := fl.OutputInterface - matches := nicName == ifName - if strings.HasSuffix(ifName, "+") { - matches = strings.HasPrefix(nicName, ifName[:n-1]) - } - return fl.OutputInterfaceInvert != matches + switch hook { + case Prerouting, Input: + return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) + case Output: + return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) + case Forward, Postrouting: + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + return true + default: + panic(fmt.Sprintf("unknown hook: %d", hook)) } +} - return true +func matchIfName(nicName string, ifName string, invert bool) bool { + n := len(ifName) + if n == 0 { + // If the interface name is omitted in the filter, any interface will match. + return true + } + // If the interface name ends with '+', any interface which begins with the + // name should be matched. + var matches bool + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return matches != invert } // NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header @@ -320,7 +340,7 @@ type Matcher interface { // used for suspicious packets. // // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet *PacketBuffer, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 1742a178d..218b218e7 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -7,6 +7,7 @@ go_test( size = "small", srcs = [ "forward_test.go", + "iptables_test.go", "link_resolution_test.go", "loopback_test.go", "multicast_broadcast_test.go", diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go new file mode 100644 index 000000000..21a8dd291 --- /dev/null +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -0,0 +1,336 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type inputIfNameMatcher struct { + name string +} + +var _ stack.Matcher = (*inputIfNameMatcher)(nil) + +func (*inputIfNameMatcher) Name() string { + return "inputIfNameMatcher" +} + +func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) { + return (hook == stack.Input && im.name != "" && im.name == inNicName), false +} + +const ( + nicID = 1 + nicName = "nic1" + anotherNicName = "nic2" + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + srcAddrV4 = "\x0a\x00\x00\x01" + dstAddrV4 = "\x0a\x00\x00\x02" + srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + payloadSize = 20 +) + +func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) { + t.Helper() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + }) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) + } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err) + } + return s, e +} + +func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) { + t.Helper() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + e := channel.New(0, header.IPv4MinimumMTU, linkAddr) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) + } + if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err) + } + return s, e +} + +func genPacketV6() *stack.PacketBuffer { + pktSize := header.IPv6MinimumSize + payloadSize + hdr := buffer.NewPrependable(pktSize) + ip := header.IPv6(hdr.Prepend(pktSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: payloadSize, + TransportProtocol: 99, + HopLimit: 255, + SrcAddr: srcAddrV6, + DstAddr: dstAddrV6, + }) + vv := hdr.View().ToVectorisedView() + return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) +} + +func genPacketV4() *stack.PacketBuffer { + pktSize := header.IPv4MinimumSize + payloadSize + hdr := buffer.NewPrependable(pktSize) + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&header.IPv4Fields{ + TOS: 0, + TotalLength: uint16(pktSize), + ID: 1, + Flags: 0, + FragmentOffset: 16, + TTL: 48, + Protocol: 99, + SrcAddr: srcAddrV4, + DstAddr: dstAddrV4, + }) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + vv := hdr.View().ToVectorisedView() + return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) +} + +func TestIPTablesStatsForInput(t *testing.T) { + tests := []struct { + name string + setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint) + setupFilter func(*testing.T, *stack.Stack) + genPacket func() *stack.PacketBuffer + proto tcpip.NetworkProtocolNumber + expectReceived int + expectInputDropped int + }{ + { + name: "IPv6 Accept", + setupStack: genStackV6, + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept", + setupStack: genStackV4, + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv6 Drop (input interface matches)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv4 Drop (input interface matches)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv6 Accept (input interface does not match)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept (input interface does not match)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv6 Drop (input interface does not match but invert is true)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ + InputInterface: anotherNicName, + InputInterfaceInvert: true, + } + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv4 Drop (input interface does not match but invert is true)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ + InputInterface: anotherNicName, + InputInterfaceInvert: true, + } + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv6 Accept (input interface does not match using a matcher)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept (input interface does not match using a matcher)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, e := test.setupStack(t) + test.setupFilter(t, s) + e.InjectInbound(test.proto, test.genPacket()) + + if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived { + t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived) + } + if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped { + t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped) + } + }) + } +} diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index 37a1a6694..c47660026 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -51,6 +51,12 @@ func init() { RegisterTestCase(FilterInputInvertDestination{}) RegisterTestCase(FilterInputSource{}) RegisterTestCase(FilterInputInvertSource{}) + RegisterTestCase(FilterInputInterfaceAccept{}) + RegisterTestCase(FilterInputInterfaceDrop{}) + RegisterTestCase(FilterInputInterface{}) + RegisterTestCase(FilterInputInterfaceBeginsWith{}) + RegisterTestCase(FilterInputInterfaceInvertDrop{}) + RegisterTestCase(FilterInputInterfaceInvertAccept{}) } // FilterInputDropUDP tests that we can drop UDP traffic. @@ -744,3 +750,195 @@ func (FilterInputInvertSource) ContainerAction(ctx context.Context, ip net.IP, i func (FilterInputInvertSource) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { return sendUDPLoop(ctx, ip, acceptPort) } + +// FilterInputInterfaceAccept tests that packets are accepted from interface +// matching the iptables rule. +type FilterInputInterfaceAccept struct{ localCase } + +var _ TestCase = FilterInputInterfaceAccept{} + +// Name implements TestCase.Name. +func (FilterInputInterfaceAccept) Name() string { + return "FilterInputInterfaceAccept" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInterfaceAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + ifname, ok := getInterfaceName() + if !ok { + return fmt.Errorf("no interface is present, except loopback") + } + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-i", ifname, "-j", "ACCEPT"); err != nil { + return err + } + if err := listenUDP(ctx, acceptPort); err != nil { + return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %w", acceptPort, err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInterfaceAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputInterfaceDrop tests that packets are dropped from interface +// matching the iptables rule. +type FilterInputInterfaceDrop struct{ localCase } + +var _ TestCase = FilterInputInterfaceDrop{} + +// Name implements TestCase.Name. +func (FilterInputInterfaceDrop) Name() string { + return "FilterInputInterfaceDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInterfaceDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + ifname, ok := getInterfaceName() + if !ok { + return fmt.Errorf("no interface is present, except loopback") + } + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-i", ifname, "-j", "DROP"); err != nil { + return err + } + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return nil + } + return fmt.Errorf("error reading: %w", err) + } + return fmt.Errorf("packets should have been dropped, but got a packet") +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInterfaceDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputInterface tests that packets are not dropped from interface which +// is not matching the interface name in the iptables rule. +type FilterInputInterface struct{ localCase } + +var _ TestCase = FilterInputInterface{} + +// Name implements TestCase.Name. +func (FilterInputInterface) Name() string { + return "FilterInputInterface" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInterface) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-i", "lo", "-j", "DROP"); err != nil { + return err + } + if err := listenUDP(ctx, acceptPort); err != nil { + return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %w", acceptPort, err) + } + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInterface) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputInterfaceBeginsWith tests that packets are dropped from an +// interface which begins with the given interface name. +type FilterInputInterfaceBeginsWith struct{ localCase } + +var _ TestCase = FilterInputInterfaceBeginsWith{} + +// Name implements TestCase.Name. +func (FilterInputInterfaceBeginsWith) Name() string { + return "FilterInputInterfaceBeginsWith" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInterfaceBeginsWith) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-i", "e+", "-j", "DROP"); err != nil { + return err + } + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return nil + } + return fmt.Errorf("error reading: %w", err) + } + return fmt.Errorf("packets should have been dropped, but got a packet") +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInterfaceBeginsWith) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputInterfaceInvertDrop tests that we selectively drop packets from +// interface not matching the interface name. +type FilterInputInterfaceInvertDrop struct{ baseCase } + +var _ TestCase = FilterInputInterfaceInvertDrop{} + +// Name implements TestCase.Name. +func (FilterInputInterfaceInvertDrop) Name() string { + return "FilterInputInterfaceInvertDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInterfaceInvertDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "!", "-i", "lo", "-j", "DROP"); err != nil { + return err + } + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return nil + } + return fmt.Errorf("error reading: %w", err) + } + return fmt.Errorf("connection on port %d should not be accepted, but was accepted", acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInterfaceInvertDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err != nil { + var operr *net.OpError + if errors.As(err, &operr) && operr.Timeout() { + return nil + } + return fmt.Errorf("error connecting: %w", err) + } + return fmt.Errorf("connection destined to port %d should not be accepted, but was accepted", acceptPort) +} + +// FilterInputInterfaceInvertAccept tests that we can selectively accept packets +// not matching the specific incoming interface. +type FilterInputInterfaceInvertAccept struct{ baseCase } + +var _ TestCase = FilterInputInterfaceInvertAccept{} + +// Name implements TestCase.Name. +func (FilterInputInterfaceInvertAccept) Name() string { + return "FilterInputInterfaceInvertAccept" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInterfaceInvertAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "!", "-i", "lo", "-j", "ACCEPT"); err != nil { + return err + } + return listenTCP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInterfaceInvertAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) +} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 9a4f60a9a..ef92e3fff 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -392,6 +392,30 @@ func TestInputInvertSource(t *testing.T) { singleTest(t, FilterInputInvertSource{}) } +func TestInputInterfaceAccept(t *testing.T) { + singleTest(t, FilterInputInterfaceAccept{}) +} + +func TestInputInterfaceDrop(t *testing.T) { + singleTest(t, FilterInputInterfaceDrop{}) +} + +func TestInputInterface(t *testing.T) { + singleTest(t, FilterInputInterface{}) +} + +func TestInputInterfaceBeginsWith(t *testing.T) { + singleTest(t, FilterInputInterfaceBeginsWith{}) +} + +func TestInputInterfaceInvertDrop(t *testing.T) { + singleTest(t, FilterInputInterfaceInvertDrop{}) +} + +func TestInputInterfaceInvertAccept(t *testing.T) { + singleTest(t, FilterInputInterfaceInvertAccept{}) +} + func TestFilterAddrs(t *testing.T) { tcs := []struct { ipv6 bool diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index a6ec5cca3..4cd770a65 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -171,7 +171,7 @@ func connectTCP(ctx context.Context, ip net.IP, port int) error { return err } if err := testutil.PollContext(ctx, callback); err != nil { - return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %v", port, err) + return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %w", port, err) } return nil |