diff options
-rw-r--r-- | pkg/sentry/socket/netfilter/ipv4.go | 30 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/ipv6.go | 31 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/netfilter.go | 9 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/owner_matcher.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/tcp_matcher.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/udp_matcher.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_types.go | 64 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_state_autogen.go | 21 |
11 files changed, 117 insertions, 96 deletions
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index 70c561cce..2f913787b 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -15,7 +15,6 @@ package netfilter import ( - "bytes" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" @@ -220,18 +219,6 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) } - n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) - if n == -1 { - n = len(iptip.OutputInterface) - } - ifname := string(iptip.OutputInterface[:n]) - - n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) - if n == -1 { - n = len(iptip.OutputInterfaceMask) - } - ifnameMask := string(iptip.OutputInterfaceMask[:n]) - return stack.IPHeaderFilter{ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), // A Protocol value of 0 indicates all protocols match. @@ -242,8 +229,11 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { Src: tcpip.Address(iptip.Src[:]), SrcMask: tcpip.Address(iptip.SrcMask[:]), SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, - OutputInterface: ifname, - OutputInterfaceMask: ifnameMask, + InputInterface: string(trimNullBytes(iptip.InputInterface[:])), + InputInterfaceMask: string(trimNullBytes(iptip.InputInterfaceMask[:])), + InputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_IN != 0, + OutputInterface: string(trimNullBytes(iptip.OutputInterface[:])), + OutputInterfaceMask: string(trimNullBytes(iptip.OutputInterfaceMask[:])), OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, }, nil } @@ -254,12 +244,12 @@ func containsUnsupportedFields4(iptip linux.IPTIP) bool { // - Dst and DstMask // - Src and SrcMask // - The inverse destination IP check flag + // - InputInterface, InputInterfaceMask and its inverse. // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInterface = [linux.IFNAMSIZ]byte{} + const flagMask = 0 // Disable any supported inverse flags. - inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) - return iptip.InputInterface != emptyInterface || - iptip.InputInterfaceMask != emptyInterface || - iptip.Flags != 0 || + const inverseMask = linux.IPT_INV_DSTIP | linux.IPT_INV_SRCIP | + linux.IPT_INV_VIA_IN | linux.IPT_INV_VIA_OUT + return iptip.Flags&^flagMask != 0 || iptip.InverseFlags&^inverseMask != 0 } diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 5dbb604f0..263d9d3b5 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -15,7 +15,6 @@ package netfilter import ( - "bytes" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" @@ -223,18 +222,6 @@ func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) } - n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) - if n == -1 { - n = len(iptip.OutputInterface) - } - ifname := string(iptip.OutputInterface[:n]) - - n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) - if n == -1 { - n = len(iptip.OutputInterfaceMask) - } - ifnameMask := string(iptip.OutputInterfaceMask[:n]) - return stack.IPHeaderFilter{ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), // In ip6tables a flag controls whether to check the protocol. @@ -245,8 +232,11 @@ func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) { Src: tcpip.Address(iptip.Src[:]), SrcMask: tcpip.Address(iptip.SrcMask[:]), SrcInvert: iptip.InverseFlags&linux.IP6T_INV_SRCIP != 0, - OutputInterface: ifname, - OutputInterfaceMask: ifnameMask, + InputInterface: string(trimNullBytes(iptip.InputInterface[:])), + InputInterfaceMask: string(trimNullBytes(iptip.InputInterfaceMask[:])), + InputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_IN != 0, + OutputInterface: string(trimNullBytes(iptip.OutputInterface[:])), + OutputInterfaceMask: string(trimNullBytes(iptip.OutputInterfaceMask[:])), OutputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_OUT != 0, }, nil } @@ -257,14 +247,13 @@ func containsUnsupportedFields6(iptip linux.IP6TIP) bool { // - Dst and DstMask // - Src and SrcMask // - The inverse destination IP check flag + // - InputInterface, InputInterfaceMask and its inverse. // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInterface = [linux.IFNAMSIZ]byte{} - flagMask := uint8(linux.IP6T_F_PROTO) + const flagMask = linux.IP6T_F_PROTO // Disable any supported inverse flags. - inverseMask := uint8(linux.IP6T_INV_DSTIP) | uint8(linux.IP6T_INV_SRCIP) | uint8(linux.IP6T_INV_VIA_OUT) - return iptip.InputInterface != emptyInterface || - iptip.InputInterfaceMask != emptyInterface || - iptip.Flags&^flagMask != 0 || + const inverseMask = linux.IP6T_INV_DSTIP | linux.IP6T_INV_SRCIP | + linux.IP6T_INV_VIA_IN | linux.IP6T_INV_VIA_OUT + return iptip.Flags&^flagMask != 0 || iptip.InverseFlags&^inverseMask != 0 || iptip.TOS != 0 } diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 26bd1abd4..7ae18b2a3 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -17,6 +17,7 @@ package netfilter import ( + "bytes" "errors" "fmt" @@ -393,3 +394,11 @@ func TargetRevision(t *kernel.Task, revPtr usermem.Addr, netProto tcpip.NetworkP rev.Revision = maxSupported return rev, nil } + +func trimNullBytes(b []byte) []byte { + n := bytes.IndexByte(b, 0) + if n == -1 { + n = len(b) + } + return b[:n] +} diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 69d13745e..176fa6116 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -112,7 +112,7 @@ func (*OwnerMatcher) Name() string { } // Match implements Matcher.Match. -func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { +func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // Support only for OUTPUT chain. // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. if hook != stack.Output { diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 352c51390..2740697b3 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -96,7 +96,7 @@ func (*TCPMatcher) Name() string { } // Match implements Matcher.Match. -func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { +func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. switch pkt.NetworkProtocolNumber { diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index c88d8268d..466d5395d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -93,7 +93,7 @@ func (*UDPMatcher) Name() string { } // Match implements Matcher.Match. -func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { +func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. switch pkt.NetworkProtocolNumber { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index f295a9192..2ccfa0822 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -321,8 +321,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -436,10 +436,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } } - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -611,7 +611,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. if !e.nic.IsLoopback() { - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -698,7 +699,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 37884505e..40176594e 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -631,8 +631,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -747,8 +747,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -897,7 +897,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. if !e.nic.IsLoopback() { - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -955,7 +956,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 09c7811fa..04af933a6 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -267,11 +267,11 @@ const ( // dropped. // // TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from -// which address and nicName can be gathered. Currently, address is only -// needed for prerouting and nicName is only needed for output. +// which address can be gathered. Currently, address is only needed for +// prerouting. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { return true } @@ -302,7 +302,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -385,10 +385,10 @@ func (it *IPTables) startReaper(interval time.Duration) { // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok { + if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -408,11 +408,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -429,7 +429,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -455,11 +455,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(pkt, hook, nicName) { + if !rule.Filter.match(pkt, hook, inNicName, outNicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -467,7 +467,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, pkt, "") + matches, hotdrop := matcher.Match(hook, pkt, inNicName, outNicName) if hotdrop { return RuleDrop, 0 } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 56a3e7861..fd9d61e39 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -210,8 +210,19 @@ type IPHeaderFilter struct { // filter will match packets that fail the source comparison. SrcInvert bool - // OutputInterface matches the name of the outgoing interface for the - // packet. + // InputInterface matches the name of the incoming interface for the packet. + InputInterface string + + // InputInterfaceMask masks the characters of the interface name when + // comparing with InputInterface. + InputInterfaceMask string + + // InputInterfaceInvert inverts the meaning of incoming interface check, + // i.e. when true the filter will match packets that fail the incoming + // interface comparison. + InputInterfaceInvert bool + + // OutputInterface matches the name of the outgoing interface for the packet. OutputInterface string // OutputInterfaceMask masks the characters of the interface name when @@ -228,7 +239,7 @@ type IPHeaderFilter struct { // // Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4 // or IPv6 header length. -func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool { +func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool { // Extract header fields. var ( // TODO(gvisor.dev/issue/170): Support other filter fields. @@ -264,26 +275,35 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo return false } - // Check the output interface. - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. - if hook == Output { - n := len(fl.OutputInterface) - if n == 0 { - return true - } - - // If the interface name ends with '+', any interface which - // begins with the name should be matched. - ifName := fl.OutputInterface - matches := nicName == ifName - if strings.HasSuffix(ifName, "+") { - matches = strings.HasPrefix(nicName, ifName[:n-1]) - } - return fl.OutputInterfaceInvert != matches + switch hook { + case Prerouting, Input: + return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) + case Output: + return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) + case Forward, Postrouting: + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + return true + default: + panic(fmt.Sprintf("unknown hook: %d", hook)) } +} - return true +func matchIfName(nicName string, ifName string, invert bool) bool { + n := len(ifName) + if n == 0 { + // If the interface name is omitted in the filter, any interface will match. + return true + } + // If the interface name ends with '+', any interface which begins with the + // name should be matched. + var matches bool + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return matches != invert } // NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header @@ -320,7 +340,7 @@ type Matcher interface { // used for suspicious packets. // // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet *PacketBuffer, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go index e22bf506d..271e0f1f4 100644 --- a/pkg/tcpip/stack/stack_state_autogen.go +++ b/pkg/tcpip/stack/stack_state_autogen.go @@ -298,6 +298,9 @@ func (fl *IPHeaderFilter) StateFields() []string { "Src", "SrcMask", "SrcInvert", + "InputInterface", + "InputInterfaceMask", + "InputInterfaceInvert", "OutputInterface", "OutputInterfaceMask", "OutputInterfaceInvert", @@ -316,9 +319,12 @@ func (fl *IPHeaderFilter) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(5, &fl.Src) stateSinkObject.Save(6, &fl.SrcMask) stateSinkObject.Save(7, &fl.SrcInvert) - stateSinkObject.Save(8, &fl.OutputInterface) - stateSinkObject.Save(9, &fl.OutputInterfaceMask) - stateSinkObject.Save(10, &fl.OutputInterfaceInvert) + stateSinkObject.Save(8, &fl.InputInterface) + stateSinkObject.Save(9, &fl.InputInterfaceMask) + stateSinkObject.Save(10, &fl.InputInterfaceInvert) + stateSinkObject.Save(11, &fl.OutputInterface) + stateSinkObject.Save(12, &fl.OutputInterfaceMask) + stateSinkObject.Save(13, &fl.OutputInterfaceInvert) } func (fl *IPHeaderFilter) afterLoad() {} @@ -332,9 +338,12 @@ func (fl *IPHeaderFilter) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(5, &fl.Src) stateSourceObject.Load(6, &fl.SrcMask) stateSourceObject.Load(7, &fl.SrcInvert) - stateSourceObject.Load(8, &fl.OutputInterface) - stateSourceObject.Load(9, &fl.OutputInterfaceMask) - stateSourceObject.Load(10, &fl.OutputInterfaceInvert) + stateSourceObject.Load(8, &fl.InputInterface) + stateSourceObject.Load(9, &fl.InputInterfaceMask) + stateSourceObject.Load(10, &fl.InputInterfaceInvert) + stateSourceObject.Load(11, &fl.OutputInterface) + stateSourceObject.Load(12, &fl.OutputInterfaceMask) + stateSourceObject.Load(13, &fl.OutputInterfaceInvert) } func (l *linkAddrEntryList) StateTypeName() string { |