summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rwxr-xr-xpkg/abi/linux/linux_abi_autogen_unsafe.go10
-rw-r--r--pkg/abi/linux/netfilter.go2
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go40
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go8
-rwxr-xr-xpkg/tcpip/stack/iptables.go42
-rwxr-xr-xpkg/tcpip/stack/iptables_types.go13
-rw-r--r--pkg/tcpip/stack/nic.go2
-rw-r--r--pkg/tcpip/stack/stack.go16
8 files changed, 105 insertions, 28 deletions
diff --git a/pkg/abi/linux/linux_abi_autogen_unsafe.go b/pkg/abi/linux/linux_abi_autogen_unsafe.go
index e9ce6e6a7..6d0ae0409 100755
--- a/pkg/abi/linux/linux_abi_autogen_unsafe.go
+++ b/pkg/abi/linux/linux_abi_autogen_unsafe.go
@@ -124,12 +124,12 @@ func (s *Statx) UnmarshalBytes(src []byte) {
// Packed implements marshal.Marshallable.Packed.
//go:nosplit
func (s *Statx) Packed() bool {
- return s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed()
+ return s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed()
}
// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
func (s *Statx) MarshalUnsafe(dst []byte) {
- if s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() {
+ if s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
safecopy.CopyIn(dst, unsafe.Pointer(s))
} else {
s.MarshalBytes(dst)
@@ -138,7 +138,7 @@ func (s *Statx) MarshalUnsafe(dst []byte) {
// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
func (s *Statx) UnmarshalUnsafe(src []byte) {
- if s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() {
+ if s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() {
safecopy.CopyOut(unsafe.Pointer(s), src)
} else {
s.UnmarshalBytes(src)
@@ -178,7 +178,7 @@ func (s *Statx) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
// CopyIn implements marshal.Marshallable.CopyIn.
//go:nosplit
func (s *Statx) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
+ if !s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() {
// Type Statx doesn't have a packed layout in memory, fall back to UnmarshalBytes.
buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay.
length, err := task.CopyInBytes(addr, buf) // escapes: okay.
@@ -204,7 +204,7 @@ func (s *Statx) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
// WriteTo implements io.WriterTo.WriteTo.
func (s *Statx) WriteTo(w io.Writer) (int64, error) {
- if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
+ if !s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() {
// Type Statx doesn't have a packed layout in memory, fall back to MarshalBytes.
buf := make([]byte, s.SizeBytes())
s.MarshalBytes(buf)
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index a8d4f9d69..46d8b0b42 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -146,7 +146,7 @@ type IPTIP struct {
// OutputInterface is the output network interface.
OutputInterface [IFNAMSIZ]byte
- // InputInterfaceMask is the intput interface mask.
+ // InputInterfaceMask is the input interface mask.
InputInterfaceMask [IFNAMSIZ]byte
// OuputInterfaceMask is the output interface mask.
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 40736fb38..41a1ce031 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -17,6 +17,7 @@
package netfilter
import (
+ "bytes"
"errors"
"fmt"
@@ -204,6 +205,16 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI
TargetOffset: linux.SizeOfIPTEntry,
},
}
+ copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst)
+ copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask)
+ copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface)
+ copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
+ if rule.Filter.DstInvert {
+ entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP
+ }
+ if rule.Filter.OutputInterfaceInvert {
+ entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
+ }
for _, matcher := range rule.Matchers {
// Serialize the matcher and add it to the
@@ -719,11 +730,27 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize {
return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask))
}
+
+ n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterface)
+ }
+ ifname := string(iptip.OutputInterface[:n])
+
+ n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterfaceMask)
+ }
+ ifnameMask := string(iptip.OutputInterfaceMask[:n])
+
return stack.IPHeaderFilter{
- Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
- Dst: tcpip.Address(iptip.Dst[:]),
- DstMask: tcpip.Address(iptip.DstMask[:]),
- DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0,
+ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ Dst: tcpip.Address(iptip.Dst[:]),
+ DstMask: tcpip.Address(iptip.DstMask[:]),
+ DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0,
+ OutputInterface: ifname,
+ OutputInterfaceMask: ifnameMask,
+ OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0,
}, nil
}
@@ -732,16 +759,15 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool {
// - Protocol
// - Dst and DstMask
// - The inverse destination IP check flag
+ // - OutputInterface, OutputInterfaceMask and its inverse.
var emptyInetAddr = linux.InetAddr{}
var emptyInterface = [linux.IFNAMSIZ]byte{}
// Disable any supported inverse flags.
- inverseMask := uint8(linux.IPT_INV_DSTIP)
+ inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_VIA_OUT)
return iptip.Src != emptyInetAddr ||
iptip.SrcMask != emptyInetAddr ||
iptip.InputInterface != emptyInterface ||
- iptip.OutputInterface != emptyInterface ||
iptip.InputInterfaceMask != emptyInterface ||
- iptip.OutputInterfaceMask != emptyInterface ||
iptip.Flags != 0 ||
iptip.InverseFlags&^inverseMask != 0
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 9db42b2a4..64046cbbf 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -249,10 +249,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
+ nicName := e.stack.FindNICNameFromID(e.NICID())
// iptables filtering. All packets that reach here are locally
// generated.
ipt := e.stack.IPTables()
- if ok := ipt.Check(stack.Output, &pkt, gso, r, ""); !ok {
+ if ok := ipt.Check(stack.Output, &pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
return nil
}
@@ -319,10 +320,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
pkt = pkt.Next()
}
+ nicName := e.stack.FindNICNameFromID(e.NICID())
// iptables filtering. All packets that reach here are locally
// generated.
ipt := e.stack.IPTables()
- dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r)
+ dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
@@ -445,7 +447,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
ipt := e.stack.IPTables()
- if ok := ipt.Check(stack.Input, &pkt, nil, nil, ""); !ok {
+ if ok := ipt.Check(stack.Input, &pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
return
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 7c3c47d50..443423b3c 100755
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -16,6 +16,7 @@ package stack
import (
"fmt"
+ "strings"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -178,7 +179,7 @@ const (
// dropped.
//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address) bool {
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool {
// Packets are manipulated only if connection and matching
// NAT rule exists.
it.connections.HandlePacket(pkt, hook, gso, r)
@@ -187,7 +188,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
for _, tablename := range it.Priorities[hook] {
table := it.Tables[tablename]
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -228,10 +229,10 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
-func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if !pkt.NatDone {
- if ok := it.Check(hook, pkt, gso, r, ""); !ok {
+ if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok {
if drop == nil {
drop = make(map[*PacketBuffer]struct{})
}
@@ -251,11 +252,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
// precondition.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
case RuleAccept:
return chainAccept
@@ -272,7 +273,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -298,7 +299,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
// precondition.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// If pkt.NetworkHeader hasn't been set yet, it will be contained in
@@ -313,7 +314,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// Check whether the packet matches the IP header filter.
- if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) {
+ if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader), hook, nicName) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
@@ -335,7 +336,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
return rule.Target.Action(pkt, &it.connections, hook, gso, r, address)
}
-func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool {
+func filterMatch(filter IPHeaderFilter, hdr header.IPv4, hook Hook, nicName string) bool {
// TODO(gvisor.dev/issue/170): Support other fields of the filter.
// Check the transport protocol.
if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() {
@@ -355,5 +356,26 @@ func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool {
return false
}
+ // Check the output interface.
+ // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
+ // hooks after supported.
+ if hook == Output {
+ n := len(filter.OutputInterface)
+ if n == 0 {
+ return true
+ }
+
+ // If the interface name ends with '+', any interface which begins
+ // with the name should be matched.
+ ifName := filter.OutputInterface
+ matches = true
+ if strings.HasSuffix(ifName, "+") {
+ matches = strings.HasPrefix(nicName, ifName[:n-1])
+ } else {
+ matches = nicName == ifName
+ }
+ return filter.OutputInterfaceInvert != matches
+ }
+
return true
}
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 1bb0ba1bd..fe06007ae 100755
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -158,6 +158,19 @@ type IPHeaderFilter struct {
// true the filter will match packets that fail the destination
// comparison.
DstInvert bool
+
+ // OutputInterface matches the name of the outgoing interface for the
+ // packet.
+ OutputInterface string
+
+ // OutputInterfaceMask masks the characters of the interface name when
+ // comparing with OutputInterface.
+ OutputInterfaceMask string
+
+ // OutputInterfaceInvert inverts the meaning of outgoing interface check,
+ // i.e. when true the filter will match packets that fail the outgoing
+ // interface comparison.
+ OutputInterfaceInvert bool
}
// A Matcher is the interface for matching packets.
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 8f4c1fe42..54103fdb3 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1233,7 +1233,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
// iptables filtering.
ipt := n.stack.IPTables()
address := n.primaryAddress(protocol)
- if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address); !ok {
+ if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address, ""); !ok {
// iptables is telling us to drop the packet.
return
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index e33fae4eb..b39ffa9fb 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1898,9 +1898,23 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
nic.mu.RLock()
defer nic.mu.RUnlock()
- // An endpoint with this id exists, check if it can be used and return it.
+ // An endpoint with this id exists, check if it can be
+ // used and return it.
return ref.ep, nil
}
}
return nil, tcpip.ErrBadAddress
}
+
+// FindNICNameFromID returns the name of the nic for the given NICID.
+func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return ""
+ }
+
+ return nic.Name()
+}