summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go30
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go31
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go9
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go14
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go14
-rw-r--r--pkg/tcpip/stack/iptables.go24
-rw-r--r--pkg/tcpip/stack/iptables_types.go64
-rw-r--r--pkg/tcpip/stack/stack_state_autogen.go21
11 files changed, 117 insertions, 96 deletions
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index 70c561cce..2f913787b 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -15,7 +15,6 @@
package netfilter
import (
- "bytes"
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -220,18 +219,6 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask))
}
- n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0)
- if n == -1 {
- n = len(iptip.OutputInterface)
- }
- ifname := string(iptip.OutputInterface[:n])
-
- n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0)
- if n == -1 {
- n = len(iptip.OutputInterfaceMask)
- }
- ifnameMask := string(iptip.OutputInterfaceMask[:n])
-
return stack.IPHeaderFilter{
Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
// A Protocol value of 0 indicates all protocols match.
@@ -242,8 +229,11 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
Src: tcpip.Address(iptip.Src[:]),
SrcMask: tcpip.Address(iptip.SrcMask[:]),
SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0,
- OutputInterface: ifname,
- OutputInterfaceMask: ifnameMask,
+ InputInterface: string(trimNullBytes(iptip.InputInterface[:])),
+ InputInterfaceMask: string(trimNullBytes(iptip.InputInterfaceMask[:])),
+ InputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_IN != 0,
+ OutputInterface: string(trimNullBytes(iptip.OutputInterface[:])),
+ OutputInterfaceMask: string(trimNullBytes(iptip.OutputInterfaceMask[:])),
OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0,
}, nil
}
@@ -254,12 +244,12 @@ func containsUnsupportedFields4(iptip linux.IPTIP) bool {
// - Dst and DstMask
// - Src and SrcMask
// - The inverse destination IP check flag
+ // - InputInterface, InputInterfaceMask and its inverse.
// - OutputInterface, OutputInterfaceMask and its inverse.
- var emptyInterface = [linux.IFNAMSIZ]byte{}
+ const flagMask = 0
// Disable any supported inverse flags.
- inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT)
- return iptip.InputInterface != emptyInterface ||
- iptip.InputInterfaceMask != emptyInterface ||
- iptip.Flags != 0 ||
+ const inverseMask = linux.IPT_INV_DSTIP | linux.IPT_INV_SRCIP |
+ linux.IPT_INV_VIA_IN | linux.IPT_INV_VIA_OUT
+ return iptip.Flags&^flagMask != 0 ||
iptip.InverseFlags&^inverseMask != 0
}
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 5dbb604f0..263d9d3b5 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -15,7 +15,6 @@
package netfilter
import (
- "bytes"
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -223,18 +222,6 @@ func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) {
return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask))
}
- n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0)
- if n == -1 {
- n = len(iptip.OutputInterface)
- }
- ifname := string(iptip.OutputInterface[:n])
-
- n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0)
- if n == -1 {
- n = len(iptip.OutputInterfaceMask)
- }
- ifnameMask := string(iptip.OutputInterfaceMask[:n])
-
return stack.IPHeaderFilter{
Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
// In ip6tables a flag controls whether to check the protocol.
@@ -245,8 +232,11 @@ func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) {
Src: tcpip.Address(iptip.Src[:]),
SrcMask: tcpip.Address(iptip.SrcMask[:]),
SrcInvert: iptip.InverseFlags&linux.IP6T_INV_SRCIP != 0,
- OutputInterface: ifname,
- OutputInterfaceMask: ifnameMask,
+ InputInterface: string(trimNullBytes(iptip.InputInterface[:])),
+ InputInterfaceMask: string(trimNullBytes(iptip.InputInterfaceMask[:])),
+ InputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_IN != 0,
+ OutputInterface: string(trimNullBytes(iptip.OutputInterface[:])),
+ OutputInterfaceMask: string(trimNullBytes(iptip.OutputInterfaceMask[:])),
OutputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_OUT != 0,
}, nil
}
@@ -257,14 +247,13 @@ func containsUnsupportedFields6(iptip linux.IP6TIP) bool {
// - Dst and DstMask
// - Src and SrcMask
// - The inverse destination IP check flag
+ // - InputInterface, InputInterfaceMask and its inverse.
// - OutputInterface, OutputInterfaceMask and its inverse.
- var emptyInterface = [linux.IFNAMSIZ]byte{}
- flagMask := uint8(linux.IP6T_F_PROTO)
+ const flagMask = linux.IP6T_F_PROTO
// Disable any supported inverse flags.
- inverseMask := uint8(linux.IP6T_INV_DSTIP) | uint8(linux.IP6T_INV_SRCIP) | uint8(linux.IP6T_INV_VIA_OUT)
- return iptip.InputInterface != emptyInterface ||
- iptip.InputInterfaceMask != emptyInterface ||
- iptip.Flags&^flagMask != 0 ||
+ const inverseMask = linux.IP6T_INV_DSTIP | linux.IP6T_INV_SRCIP |
+ linux.IP6T_INV_VIA_IN | linux.IP6T_INV_VIA_OUT
+ return iptip.Flags&^flagMask != 0 ||
iptip.InverseFlags&^inverseMask != 0 ||
iptip.TOS != 0
}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 26bd1abd4..7ae18b2a3 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -17,6 +17,7 @@
package netfilter
import (
+ "bytes"
"errors"
"fmt"
@@ -393,3 +394,11 @@ func TargetRevision(t *kernel.Task, revPtr usermem.Addr, netProto tcpip.NetworkP
rev.Revision = maxSupported
return rev, nil
}
+
+func trimNullBytes(b []byte) []byte {
+ n := bytes.IndexByte(b, 0)
+ if n == -1 {
+ n = len(b)
+ }
+ return b[:n]
+}
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
index 69d13745e..176fa6116 100644
--- a/pkg/sentry/socket/netfilter/owner_matcher.go
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -112,7 +112,7 @@ func (*OwnerMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) {
// Support only for OUTPUT chain.
// TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also.
if hook != stack.Output {
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 352c51390..2740697b3 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -96,7 +96,7 @@ func (*TCPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) {
// TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
// into the stack.Check codepath as matchers are added.
switch pkt.NetworkProtocolNumber {
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index c88d8268d..466d5395d 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -93,7 +93,7 @@ func (*UDPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) {
// TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
// into the stack.Check codepath as matchers are added.
switch pkt.NetworkProtocolNumber {
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index f295a9192..2ccfa0822 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -321,8 +321,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -436,10 +436,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
}
- nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
@@ -611,7 +611,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
if !e.nic.IsLoopback() {
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -698,7 +699,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.ip.IPTablesInputDropped.Increment()
return
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 37884505e..40176594e 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -631,8 +631,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -747,8 +747,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
@@ -897,7 +897,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
if !e.nic.IsLoopback() {
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -955,7 +956,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesInputDropped.Increment()
return
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 09c7811fa..04af933a6 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -267,11 +267,11 @@ const (
// dropped.
//
// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from
-// which address and nicName can be gathered. Currently, address is only
-// needed for prerouting and nicName is only needed for output.
+// which address can be gathered. Currently, address is only needed for
+// prerouting.
//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool {
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
return true
}
@@ -302,7 +302,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
table = it.v4Tables[tableID]
}
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -385,10 +385,10 @@ func (it *IPTables) startReaper(interval time.Duration) {
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
-func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if !pkt.NatDone {
- if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok {
+ if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok {
if drop == nil {
drop = make(map[*PacketBuffer]struct{})
}
@@ -408,11 +408,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict {
case RuleAccept:
return chainAccept
@@ -429,7 +429,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -455,11 +455,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
- if !rule.Filter.match(pkt, hook, nicName) {
+ if !rule.Filter.match(pkt, hook, inNicName, outNicName) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
@@ -467,7 +467,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// Go through each rule matcher. If they all match, run
// the rule target.
for _, matcher := range rule.Matchers {
- matches, hotdrop := matcher.Match(hook, pkt, "")
+ matches, hotdrop := matcher.Match(hook, pkt, inNicName, outNicName)
if hotdrop {
return RuleDrop, 0
}
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 56a3e7861..fd9d61e39 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -210,8 +210,19 @@ type IPHeaderFilter struct {
// filter will match packets that fail the source comparison.
SrcInvert bool
- // OutputInterface matches the name of the outgoing interface for the
- // packet.
+ // InputInterface matches the name of the incoming interface for the packet.
+ InputInterface string
+
+ // InputInterfaceMask masks the characters of the interface name when
+ // comparing with InputInterface.
+ InputInterfaceMask string
+
+ // InputInterfaceInvert inverts the meaning of incoming interface check,
+ // i.e. when true the filter will match packets that fail the incoming
+ // interface comparison.
+ InputInterfaceInvert bool
+
+ // OutputInterface matches the name of the outgoing interface for the packet.
OutputInterface string
// OutputInterfaceMask masks the characters of the interface name when
@@ -228,7 +239,7 @@ type IPHeaderFilter struct {
//
// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4
// or IPv6 header length.
-func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool {
+func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool {
// Extract header fields.
var (
// TODO(gvisor.dev/issue/170): Support other filter fields.
@@ -264,26 +275,35 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo
return false
}
- // Check the output interface.
- // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
- // hooks after supported.
- if hook == Output {
- n := len(fl.OutputInterface)
- if n == 0 {
- return true
- }
-
- // If the interface name ends with '+', any interface which
- // begins with the name should be matched.
- ifName := fl.OutputInterface
- matches := nicName == ifName
- if strings.HasSuffix(ifName, "+") {
- matches = strings.HasPrefix(nicName, ifName[:n-1])
- }
- return fl.OutputInterfaceInvert != matches
+ switch hook {
+ case Prerouting, Input:
+ return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert)
+ case Output:
+ return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert)
+ case Forward, Postrouting:
+ // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
+ // hooks after supported.
+ return true
+ default:
+ panic(fmt.Sprintf("unknown hook: %d", hook))
}
+}
- return true
+func matchIfName(nicName string, ifName string, invert bool) bool {
+ n := len(ifName)
+ if n == 0 {
+ // If the interface name is omitted in the filter, any interface will match.
+ return true
+ }
+ // If the interface name ends with '+', any interface which begins with the
+ // name should be matched.
+ var matches bool
+ if strings.HasSuffix(ifName, "+") {
+ matches = strings.HasPrefix(nicName, ifName[:n-1])
+ } else {
+ matches = nicName == ifName
+ }
+ return matches != invert
}
// NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header
@@ -320,7 +340,7 @@ type Matcher interface {
// used for suspicious packets.
//
// Precondition: packet.NetworkHeader is set.
- Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
+ Match(hook Hook, packet *PacketBuffer, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool)
}
// A Target is the interface for taking an action for a packet.
diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go
index e22bf506d..271e0f1f4 100644
--- a/pkg/tcpip/stack/stack_state_autogen.go
+++ b/pkg/tcpip/stack/stack_state_autogen.go
@@ -298,6 +298,9 @@ func (fl *IPHeaderFilter) StateFields() []string {
"Src",
"SrcMask",
"SrcInvert",
+ "InputInterface",
+ "InputInterfaceMask",
+ "InputInterfaceInvert",
"OutputInterface",
"OutputInterfaceMask",
"OutputInterfaceInvert",
@@ -316,9 +319,12 @@ func (fl *IPHeaderFilter) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(5, &fl.Src)
stateSinkObject.Save(6, &fl.SrcMask)
stateSinkObject.Save(7, &fl.SrcInvert)
- stateSinkObject.Save(8, &fl.OutputInterface)
- stateSinkObject.Save(9, &fl.OutputInterfaceMask)
- stateSinkObject.Save(10, &fl.OutputInterfaceInvert)
+ stateSinkObject.Save(8, &fl.InputInterface)
+ stateSinkObject.Save(9, &fl.InputInterfaceMask)
+ stateSinkObject.Save(10, &fl.InputInterfaceInvert)
+ stateSinkObject.Save(11, &fl.OutputInterface)
+ stateSinkObject.Save(12, &fl.OutputInterfaceMask)
+ stateSinkObject.Save(13, &fl.OutputInterfaceInvert)
}
func (fl *IPHeaderFilter) afterLoad() {}
@@ -332,9 +338,12 @@ func (fl *IPHeaderFilter) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(5, &fl.Src)
stateSourceObject.Load(6, &fl.SrcMask)
stateSourceObject.Load(7, &fl.SrcInvert)
- stateSourceObject.Load(8, &fl.OutputInterface)
- stateSourceObject.Load(9, &fl.OutputInterfaceMask)
- stateSourceObject.Load(10, &fl.OutputInterfaceInvert)
+ stateSourceObject.Load(8, &fl.InputInterface)
+ stateSourceObject.Load(9, &fl.InputInterfaceMask)
+ stateSourceObject.Load(10, &fl.InputInterfaceInvert)
+ stateSourceObject.Load(11, &fl.OutputInterface)
+ stateSourceObject.Load(12, &fl.OutputInterfaceMask)
+ stateSourceObject.Load(13, &fl.OutputInterfaceInvert)
}
func (l *linkAddrEntryList) StateTypeName() string {