summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/abi/linux/netfilter.go11
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go6
-rw-r--r--pkg/sentry/socket/netfilter/targets.go188
-rw-r--r--pkg/tcpip/network/internal/ip/stats.go46
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go26
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go108
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go26
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go109
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/conntrack.go234
-rw-r--r--pkg/tcpip/stack/hook_string.go41
-rw-r--r--pkg/tcpip/stack/iptables.go7
-rw-r--r--pkg/tcpip/stack/iptables_targets.go78
-rw-r--r--pkg/tcpip/stack/packet_buffer.go11
-rw-r--r--pkg/tcpip/tcpip.go4
-rw-r--r--test/iptables/iptables_test.go8
-rw-r--r--test/iptables/iptables_util.go61
-rw-r--r--test/iptables/nat.go122
18 files changed, 850 insertions, 237 deletions
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 378f1baf3..775bbc759 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -375,6 +375,17 @@ type XTRedirectTarget struct {
// SizeOfXTRedirectTarget is the size of an XTRedirectTarget.
const SizeOfXTRedirectTarget = 56
+// XTSNATTarget triggers Source NAT when reached.
+// Adding 4 bytes of padding to make the struct 8 byte aligned.
+type XTSNATTarget struct {
+ Target XTEntryTarget
+ NfRange NfNATIPV4MultiRangeCompat
+ _ [4]byte
+}
+
+// SizeOfXTSNATTarget is the size of an XTSNATTarget.
+const SizeOfXTSNATTarget = 56
+
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
//
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 5200e08ed..c6fa3fd16 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -274,10 +274,10 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
}
// TODO(gvisor.dev/issue/170): Support other chains.
- // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now,
- // make sure all other chains point to ACCEPT rules.
+ // Since we don't support FORWARD, yet, make sure all other chains point to
+ // ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
- if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting {
+ if hook := stack.Hook(hook); hook == stack.Forward {
if ruleIdx == stack.HookUnset {
continue
}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 80f8c6430..38b6491e2 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -35,6 +35,11 @@ const ErrorTargetName = "ERROR"
// change the destination port and/or IP for packets.
const RedirectTargetName = "REDIRECT"
+// SNATTargetName is used to mark targets as SNAT targets. SNAT targets should
+// be reached for only NAT table. These targets will change the source port
+// and/or IP for packets.
+const SNATTargetName = "SNAT"
+
func init() {
// Standard targets include ACCEPT, DROP, RETURN, and JUMP.
registerTargetMaker(&standardTargetMaker{
@@ -59,6 +64,13 @@ func init() {
registerTargetMaker(&nfNATTargetMaker{
NetworkProtocol: header.IPv6ProtocolNumber,
})
+
+ registerTargetMaker(&snatTargetMakerV4{
+ NetworkProtocol: header.IPv4ProtocolNumber,
+ })
+ registerTargetMaker(&snatTargetMakerV6{
+ NetworkProtocol: header.IPv6ProtocolNumber,
+ })
}
// The stack package provides some basic, useful targets for us. The following
@@ -131,6 +143,17 @@ func (rt *redirectTarget) id() targetID {
}
}
+type snatTarget struct {
+ stack.SNATTarget
+}
+
+func (st *snatTarget) id() targetID {
+ return targetID{
+ name: SNATTargetName,
+ networkProtocol: st.NetworkProtocol,
+ }
+}
+
type standardTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
@@ -341,7 +364,7 @@ type nfNATTarget struct {
Range linux.NFNATRange
}
-const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
+const nfNATMarshalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
type nfNATTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
@@ -358,7 +381,7 @@ func (*nfNATTargetMaker) marshal(target target) []byte {
rt := target.(*redirectTarget)
nt := nfNATTarget{
Target: linux.XTEntryTarget{
- TargetSize: nfNATMarhsalledSize,
+ TargetSize: nfNATMarshalledSize,
},
Range: linux.NFNATRange{
Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED,
@@ -371,12 +394,12 @@ func (*nfNATTargetMaker) marshal(target target) []byte {
nt.Range.MinProto = htons(rt.Port)
nt.Range.MaxProto = nt.Range.MinProto
- ret := make([]byte, 0, nfNATMarhsalledSize)
+ ret := make([]byte, 0, nfNATMarshalledSize)
return binary.Marshal(ret, hostarch.ByteOrder, nt)
}
func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
- if size := nfNATMarhsalledSize; len(buf) < size {
+ if size := nfNATMarshalledSize; len(buf) < size {
nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size)
return nil, syserr.ErrInvalidArgument
}
@@ -387,7 +410,7 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
}
var natRange linux.NFNATRange
- buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize]
+ buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
binary.Unmarshal(buf, hostarch.ByteOrder, &natRange)
// We don't support port or address ranges.
@@ -418,6 +441,161 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
return &target, nil
}
+type snatTargetMakerV4 struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (st *snatTargetMakerV4) id() targetID {
+ return targetID{
+ name: SNATTargetName,
+ networkProtocol: st.NetworkProtocol,
+ }
+}
+
+func (*snatTargetMakerV4) marshal(target target) []byte {
+ st := target.(*snatTarget)
+ // This is a snat target named snat.
+ xt := linux.XTSNATTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTSNATTarget,
+ },
+ }
+ copy(xt.Target.Name[:], SNATTargetName)
+
+ xt.NfRange.RangeSize = 1
+ xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED
+ xt.NfRange.RangeIPV4.MinPort = htons(st.Port)
+ xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort
+ copy(xt.NfRange.RangeIPV4.MinIP[:], st.Addr)
+ copy(xt.NfRange.RangeIPV4.MaxIP[:], st.Addr)
+ ret := make([]byte, 0, linux.SizeOfXTSNATTarget)
+ return binary.Marshal(ret, hostarch.ByteOrder, xt)
+}
+
+func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
+ if len(buf) < linux.SizeOfXTSNATTarget {
+ nflog("snatTargetMakerV4: buf has insufficient size for snat target %d", len(buf))
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ nflog("snatTargetMakerV4: bad proto %d", p)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var st linux.XTSNATTarget
+ buf = buf[:linux.SizeOfXTSNATTarget]
+ binary.Unmarshal(buf, hostarch.ByteOrder, &st)
+
+ // Copy linux.XTSNATTarget to stack.SNATTarget.
+ target := snatTarget{SNATTarget: stack.SNATTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ }}
+
+ // RangeSize should be 1.
+ nfRange := st.NfRange
+ if nfRange.RangeSize != 1 {
+ nflog("snatTargetMakerV4: bad rangesize %d", nfRange.RangeSize)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/5772): If the rule doesn't specify the source port,
+ // choose one automatically.
+ if nfRange.RangeIPV4.MinPort == 0 {
+ nflog("snatTargetMakerV4: snat target needs to specify a non-zero port")
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Port range is not supported yet.
+ if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
+ nflog("snatTargetMakerV4: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
+ return nil, syserr.ErrInvalidArgument
+ }
+ if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP {
+ nflog("snatTargetMakerV4: MinIP != MaxIP (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.Port = ntohs(nfRange.RangeIPV4.MinPort)
+
+ return &target, nil
+}
+
+type snatTargetMakerV6 struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (st *snatTargetMakerV6) id() targetID {
+ return targetID{
+ name: SNATTargetName,
+ networkProtocol: st.NetworkProtocol,
+ revision: 1,
+ }
+}
+
+func (*snatTargetMakerV6) marshal(target target) []byte {
+ st := target.(*snatTarget)
+ nt := nfNATTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: nfNATMarshalledSize,
+ },
+ Range: linux.NFNATRange{
+ Flags: linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED,
+ },
+ }
+ copy(nt.Target.Name[:], SNATTargetName)
+ copy(nt.Range.MinAddr[:], st.Addr)
+ copy(nt.Range.MaxAddr[:], st.Addr)
+ nt.Range.MinProto = htons(st.Port)
+ nt.Range.MaxProto = nt.Range.MinProto
+
+ ret := make([]byte, 0, nfNATMarshalledSize)
+ return binary.Marshal(ret, hostarch.ByteOrder, nt)
+}
+
+func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
+ if size := nfNATMarshalledSize; len(buf) < size {
+ nflog("snatTargetMakerV6: buf has insufficient size (%d) for SNAT V6 target (%d)", len(buf), size)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ nflog("snatTargetMakerV6: bad proto %d", p)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var natRange linux.NFNATRange
+ buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
+ binary.Unmarshal(buf, hostarch.ByteOrder, &natRange)
+
+ // TODO(gvisor.dev/issue/5689): Support port or address ranges.
+ if natRange.MinAddr != natRange.MaxAddr {
+ nflog("snatTargetMakerV6: MinAddr and MaxAddr are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+ if natRange.MinProto != natRange.MaxProto {
+ nflog("snatTargetMakerV6: MinProto and MaxProto are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/5698): Support other NF_NAT_RANGE flags.
+ if natRange.Flags != linux.NF_NAT_RANGE_MAP_IPS|linux.NF_NAT_RANGE_PROTO_SPECIFIED {
+ nflog("snatTargetMakerV6: invalid range flags %d", natRange.Flags)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ target := snatTarget{
+ SNATTarget: stack.SNATTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ Addr: tcpip.Address(natRange.MinAddr[:]),
+ Port: ntohs(natRange.MinProto),
+ },
+ }
+
+ return &target, nil
+}
+
// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to an stack.Verdict.
func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) {
diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go
index b6f39ddb1..d06b26309 100644
--- a/pkg/tcpip/network/internal/ip/stats.go
+++ b/pkg/tcpip/network/internal/ip/stats.go
@@ -21,53 +21,56 @@ import "gvisor.dev/gvisor/pkg/tcpip"
// MultiCounterIPStats holds IP statistics, each counter may have several
// versions.
type MultiCounterIPStats struct {
- // PacketsReceived is the total number of IP packets received from the link
- // layer.
+ // PacketsReceived is the number of IP packets received from the link layer.
PacketsReceived tcpip.MultiCounterStat
- // DisabledPacketsReceived is the total number of IP packets received from the
- // link layer when the IP layer is disabled.
+ // DisabledPacketsReceived is the number of IP packets received from the link
+ // layer when the IP layer is disabled.
DisabledPacketsReceived tcpip.MultiCounterStat
- // InvalidDestinationAddressesReceived is the total number of IP packets
- // received with an unknown or invalid destination address.
+ // InvalidDestinationAddressesReceived is the number of IP packets received
+ // with an unknown or invalid destination address.
InvalidDestinationAddressesReceived tcpip.MultiCounterStat
- // InvalidSourceAddressesReceived is the total number of IP packets received
- // with a source address that should never have been received on the wire.
+ // InvalidSourceAddressesReceived is the number of IP packets received with a
+ // source address that should never have been received on the wire.
InvalidSourceAddressesReceived tcpip.MultiCounterStat
- // PacketsDelivered is the total number of incoming IP packets that are
- // successfully delivered to the transport layer.
+ // PacketsDelivered is the number of incoming IP packets that are successfully
+ // delivered to the transport layer.
PacketsDelivered tcpip.MultiCounterStat
- // PacketsSent is the total number of IP packets sent via WritePacket.
+ // PacketsSent is the number of IP packets sent via WritePacket.
PacketsSent tcpip.MultiCounterStat
- // OutgoingPacketErrors is the total number of IP packets which failed to
- // write to a link-layer endpoint.
+ // OutgoingPacketErrors is the number of IP packets which failed to write to a
+ // link-layer endpoint.
OutgoingPacketErrors tcpip.MultiCounterStat
- // MalformedPacketsReceived is the total number of IP Packets that were
- // dropped due to the IP packet header failing validation checks.
+ // MalformedPacketsReceived is the number of IP Packets that were dropped due
+ // to the IP packet header failing validation checks.
MalformedPacketsReceived tcpip.MultiCounterStat
- // MalformedFragmentsReceived is the total number of IP Fragments that were
- // dropped due to the fragment failing validation checks.
+ // MalformedFragmentsReceived is the number of IP Fragments that were dropped
+ // due to the fragment failing validation checks.
MalformedFragmentsReceived tcpip.MultiCounterStat
- // IPTablesPreroutingDropped is the total number of IP packets dropped in the
+ // IPTablesPreroutingDropped is the number of IP packets dropped in the
// Prerouting chain.
IPTablesPreroutingDropped tcpip.MultiCounterStat
- // IPTablesInputDropped is the total number of IP packets dropped in the Input
+ // IPTablesInputDropped is the number of IP packets dropped in the Input
// chain.
IPTablesInputDropped tcpip.MultiCounterStat
- // IPTablesOutputDropped is the total number of IP packets dropped in the
- // Output chain.
+ // IPTablesOutputDropped is the number of IP packets dropped in the Output
+ // chain.
IPTablesOutputDropped tcpip.MultiCounterStat
+ // IPTablesPostroutingDropped is the number of IP packets dropped in the
+ // Postrouting chain.
+ IPTablesPostroutingDropped tcpip.MultiCounterStat
+
// TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
// of IPStats.
@@ -98,6 +101,7 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped)
m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped)
m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped)
+ m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped)
m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived)
m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived)
m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived)
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 2e44f8523..9a3dc78cb 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -415,6 +415,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return nil
}
+ // Postrouting NAT can only change the source address, and does not alter the
+ // route or outgoing interface of the packet.
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesPostroutingDropped.Increment()
+ return nil
+ }
+
stats := e.stats.ip
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
@@ -486,9 +495,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName)
- stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
- for pkt := range dropped {
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
+ stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
+ for pkt := range outputDropped {
pkts.Remove(pkt)
}
@@ -510,6 +519,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
+ // We ignore the list of NAT-ed packets here because Postrouting NAT can only
+ // change the source address, and does not alter the route or outgoing
+ // interface of the packet.
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, gso, r, "" /* inNicName */, outNicName)
+ stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
+ for pkt := range postroutingDropped {
+ pkts.Remove(pkt)
+ }
+
// The rest of the packets can be delivered to the NIC as a batch.
pktsLen := pkts.Len()
written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
@@ -517,7 +535,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
// Dropped packets aren't errors, so include them in the return value.
- return locallyDelivered + written + len(dropped), err
+ return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index eba91c68c..8e6e81005 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -2612,34 +2612,36 @@ func TestWriteStats(t *testing.T) {
const nPackets = 3
tests := []struct {
- name string
- setup func(*testing.T, *stack.Stack)
- allowPackets int
- expectSent int
- expectDropped int
- expectWritten int
+ name string
+ setup func(*testing.T, *stack.Stack)
+ allowPackets int
+ expectSent int
+ expectOutputDropped int
+ expectPostroutingDropped int
+ expectWritten int
}{
{
name: "Accept all",
// No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: math.MaxInt32,
- expectSent: nPackets,
- expectDropped: 0,
- expectWritten: nPackets,
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
}, {
name: "Accept all with error",
// No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: nPackets - 1,
- expectSent: nPackets - 1,
- expectDropped: 0,
- expectWritten: nPackets - 1,
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: nPackets - 1,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets - 1,
}, {
- name: "Drop all",
+ name: "Drop all with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule.
- t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
@@ -2648,16 +2650,32 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %s", err)
}
},
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectDropped: nPackets,
- expectWritten: nPackets,
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectOutputDropped: nPackets,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
}, {
- name: "Drop some",
+ name: "Drop all with Postrouting chain",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ ipt := stk.IPTables()
+ filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %s", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule that matches only 1
// of the 3 packets.
- t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
// We'll match and DROP the last packet.
@@ -2670,10 +2688,33 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %s", err)
}
},
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectDropped: 1,
- expectWritten: nPackets,
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 1,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some with Postrouting chain",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Postrouting DROP rule that matches only 1
+ // of the 3 packets.
+ ipt := stk.IPTables()
+ filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %s", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 1,
+ expectWritten: nPackets,
},
}
@@ -2724,13 +2765,16 @@ func TestWriteStats(t *testing.T) {
nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
- t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
+ t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
}
- if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
- t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
+ t.Errorf("got rt.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
}
if nWritten != test.expectWritten {
- t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
}
})
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index d36cefcd0..2e515379c 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -769,6 +769,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return nil
}
+ // Postrouting NAT can only change the source address, and does not alter the
+ // route or outgoing interface of the packet.
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesPostroutingDropped.Increment()
+ return nil
+ }
+
stats := e.stats.ip
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
@@ -840,9 +849,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
- stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
- for pkt := range dropped {
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
+ stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
+ for pkt := range outputDropped {
pkts.Remove(pkt)
}
@@ -863,6 +872,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
locallyDelivered++
}
+ // We ignore the list of NAT-ed packets here because Postrouting NAT can only
+ // change the source address, and does not alter the route or outgoing
+ // interface of the packet.
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, gso, r, "" /* inNicName */, outNicName)
+ stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
+ for pkt := range postroutingDropped {
+ pkts.Remove(pkt)
+ }
+
// The rest of the packets can be delivered to the NIC as a batch.
pktsLen := pkts.Len()
written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
@@ -870,7 +888,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
// Dropped packets aren't errors, so include them in the return value.
- return locallyDelivered + written + len(dropped), err
+ return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index c206cebeb..a620e9ad9 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -2468,34 +2468,36 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
func TestWriteStats(t *testing.T) {
const nPackets = 3
tests := []struct {
- name string
- setup func(*testing.T, *stack.Stack)
- allowPackets int
- expectSent int
- expectDropped int
- expectWritten int
+ name string
+ setup func(*testing.T, *stack.Stack)
+ allowPackets int
+ expectSent int
+ expectOutputDropped int
+ expectPostroutingDropped int
+ expectWritten int
}{
{
name: "Accept all",
// No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: math.MaxInt32,
- expectSent: nPackets,
- expectDropped: 0,
- expectWritten: nPackets,
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
}, {
name: "Accept all with error",
// No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: nPackets - 1,
- expectSent: nPackets - 1,
- expectDropped: 0,
- expectWritten: nPackets - 1,
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: nPackets - 1,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets - 1,
}, {
- name: "Drop all",
+ name: "Drop all with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule.
- t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
@@ -2504,16 +2506,33 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %v", err)
}
},
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectDropped: nPackets,
- expectWritten: nPackets,
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectOutputDropped: nPackets,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
}, {
- name: "Drop some",
+ name: "Drop all with Postrouting chain",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule.
+ ipt := stk.IPTables()
+ filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule that matches only 1
// of the 3 packets.
- t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
// We'll match and DROP the last packet.
@@ -2526,10 +2545,33 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %v", err)
}
},
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectDropped: 1,
- expectWritten: nPackets,
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 1,
+ expectPostroutingDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some with Postrouting chain",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Postrouting DROP rule that matches only 1
+ // of the 3 packets.
+ ipt := stk.IPTables()
+ filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectOutputDropped: 0,
+ expectPostroutingDropped: 1,
+ expectWritten: nPackets,
},
}
@@ -2578,13 +2620,16 @@ func TestWriteStats(t *testing.T) {
nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
- t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
+ t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
}
- if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
- t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
+ t.Errorf("got r.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
}
if nWritten != test.expectWritten {
- t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
}
})
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 49362333a..bbd75c73a 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -45,6 +45,7 @@ go_library(
"addressable_endpoint_state.go",
"conntrack.go",
"headertype_string.go",
+ "hook_string.go",
"icmp_rate_limit.go",
"iptables.go",
"iptables_state.go",
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 3f083928f..41e964cf3 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -16,6 +16,7 @@ package stack
import (
"encoding/binary"
+ "fmt"
"sync"
"time"
@@ -29,7 +30,7 @@ import (
// The connection is created for a packet if it does not exist. Every
// connection contains two tuples (original and reply). The tuples are
// manipulated if there is a matching NAT rule. The packet is modified by
-// looking at the tuples in the Prerouting and Output hooks.
+// looking at the tuples in each hook.
//
// Currently, only TCP tracking is supported.
@@ -46,12 +47,14 @@ const (
)
// Manipulation type for the connection.
+// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and
+// DNAT at the same time.
type manipType int
const (
manipNone manipType = iota
- manipDstPrerouting
- manipDstOutput
+ manipSource
+ manipDestination
)
// tuple holds a connection's identifying and manipulating data in one
@@ -108,6 +111,7 @@ type conn struct {
reply tuple
// manip indicates if the packet should be manipulated. It is immutable.
+ // TODO(gvisor.dev/issue/5696): Support updating manipulation type.
manip manipType
// tcbHook indicates if the packet is inbound or outbound to
@@ -124,6 +128,18 @@ type conn struct {
lastUsed time.Time `state:".(unixTime)"`
}
+// newConn creates new connection.
+func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
+ conn := conn{
+ manip: manip,
+ tcbHook: hook,
+ lastUsed: time.Now(),
+ }
+ conn.original = tuple{conn: &conn, tupleID: orig}
+ conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
+ return &conn
+}
+
// timedOut returns whether the connection timed out based on its state.
func (cn *conn) timedOut(now time.Time) bool {
const establishedTimeout = 5 * 24 * time.Hour
@@ -219,18 +235,6 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
}, nil
}
-// newConn creates new connection.
-func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
- conn := conn{
- manip: manip,
- tcbHook: hook,
- lastUsed: time.Now(),
- }
- conn.original = tuple{conn: &conn, tupleID: orig}
- conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
- return &conn
-}
-
func (ct *ConnTrack) init() {
ct.mu.Lock()
defer ct.mu.Unlock()
@@ -284,20 +288,41 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint1
return nil
}
- // Create a new connection and change the port as per the iptables
- // rule. This tuple will be used to manipulate the packet in
- // handlePacket.
replyTID := tid.reply()
replyTID.srcAddr = address
replyTID.srcPort = port
- var manip manipType
- switch hook {
- case Prerouting:
- manip = manipDstPrerouting
- case Output:
- manip = manipDstOutput
+
+ conn, _ := ct.connForTID(tid)
+ if conn != nil {
+ // The connection is already tracked.
+ // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
+ return nil
}
- conn := newConn(tid, replyTID, manip, hook)
+ conn = newConn(tid, replyTID, manipDestination, hook)
+ ct.insertConn(conn)
+ return conn
+}
+
+func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil
+ }
+ if hook != Input && hook != Postrouting {
+ return nil
+ }
+
+ replyTID := tid.reply()
+ replyTID.dstAddr = address
+ replyTID.dstPort = port
+
+ conn, _ := ct.connForTID(tid)
+ if conn != nil {
+ // The connection is already tracked.
+ // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
+ return nil
+ }
+ conn = newConn(tid, replyTID, manipSource, hook)
ct.insertConn(conn)
return conn
}
@@ -322,6 +347,7 @@ func (ct *ConnTrack) insertConn(conn *conn) {
// Now that we hold the locks, ensure the tuple hasn't been inserted by
// another thread.
+ // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too?
alreadyInserted := false
for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
if other.tupleID == conn.original.tupleID {
@@ -343,86 +369,6 @@ func (ct *ConnTrack) insertConn(conn *conn) {
}
}
-// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
-// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
-func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
- // If this is a noop entry, don't do anything.
- if conn.manip == manipNone {
- return
- }
-
- netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
-
- // For prerouting redirection, packets going in the original direction
- // have their destinations modified and replies have their sources
- // modified.
- switch dir {
- case dirOriginal:
- port := conn.reply.srcPort
- tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.reply.srcAddr)
- case dirReply:
- port := conn.original.dstPort
- tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.original.dstAddr)
- }
-
- // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
- // on inbound packets, so we don't recalculate them. However, we should
- // support cases when they are validated, e.g. when we can't offload
- // receive checksumming.
-
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
-}
-
-// handlePacketOutput manipulates ports for packets in Output hook.
-func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
- // If this is a noop entry, don't do anything.
- if conn.manip == manipNone {
- return
- }
-
- netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
-
- // For output redirection, packets going in the original direction
- // have their destinations modified and replies have their sources
- // modified. For prerouting redirection, we only reach this point
- // when replying, so packet sources are modified.
- if conn.manip == manipDstOutput && dir == dirOriginal {
- port := conn.reply.srcPort
- tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.reply.srcAddr)
- } else {
- port := conn.original.dstPort
- tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.original.dstAddr)
- }
-
- // Calculate the TCP checksum and set it.
- tcpHeader.SetChecksum(0)
- length := uint16(len(tcpHeader) + pkt.Data().Size())
- xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- if gso != nil && gso.NeedsCsum {
- tcpHeader.SetChecksum(xsum)
- } else if r.RequiresTXTransportChecksum() {
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
- }
-
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
-}
-
// handlePacket will manipulate the port and address of the packet if the
// connection exists. Returns whether, after the packet traverses the tables,
// it should create a new entry in the table.
@@ -431,7 +377,9 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
return false
}
- if hook != Prerouting && hook != Output {
+ switch hook {
+ case Prerouting, Input, Output, Postrouting:
+ default:
return false
}
@@ -441,23 +389,79 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
}
conn, dir := ct.connFor(pkt)
- // Connection or Rule not found for the packet.
+ // Connection not found for the packet.
if conn == nil {
- return true
+ // If this is the last hook in the data path for this packet (Input if
+ // incoming, Postrouting if outgoing), indicate that a connection should be
+ // inserted by the end of this hook.
+ return hook == Input || hook == Postrouting
}
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
if len(tcpHeader) < header.TCPMinimumSize {
return false
}
+ // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
+ // validated if checksum offloading is off. It may require IP defrag if the
+ // packets are fragmented.
+
+ switch hook {
+ case Prerouting, Output:
+ if conn.manip == manipDestination {
+ switch dir {
+ case dirOriginal:
+ tcpHeader.SetDestinationPort(conn.reply.srcPort)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
+ case dirReply:
+ tcpHeader.SetSourcePort(conn.original.dstPort)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
+ }
+ pkt.NatDone = true
+ }
+ case Input, Postrouting:
+ if conn.manip == manipSource {
+ switch dir {
+ case dirOriginal:
+ tcpHeader.SetSourcePort(conn.reply.dstPort)
+ netHeader.SetSourceAddress(conn.reply.dstAddr)
+ case dirReply:
+ tcpHeader.SetDestinationPort(conn.original.srcPort)
+ netHeader.SetDestinationAddress(conn.original.srcAddr)
+ }
+ pkt.NatDone = true
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized hook = %s", hook))
+ }
+ if !pkt.NatDone {
+ return false
+ }
+
switch hook {
- case Prerouting:
- handlePacketPrerouting(pkt, conn, dir)
- case Output:
- handlePacketOutput(pkt, conn, gso, r, dir)
+ case Prerouting, Input:
+ case Output, Postrouting:
+ // Calculate the TCP checksum and set it.
+ tcpHeader.SetChecksum(0)
+ length := uint16(len(tcpHeader) + pkt.Data().Size())
+ xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
+ if gso != nil && gso.NeedsCsum {
+ tcpHeader.SetChecksum(xsum)
+ } else if r.RequiresTXTransportChecksum() {
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
+ tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized hook = %s", hook))
+ }
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
- pkt.NatDone = true
// Update the state of tcb.
// TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
@@ -638,8 +642,8 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
if conn == nil {
// Not a tracked connection.
return "", 0, &tcpip.ErrNotConnected{}
- } else if conn.manip == manipNone {
- // Unmanipulated connection.
+ } else if conn.manip != manipDestination {
+ // Unmanipulated destination.
return "", 0, &tcpip.ErrInvalidOptionValue{}
}
diff --git a/pkg/tcpip/stack/hook_string.go b/pkg/tcpip/stack/hook_string.go
new file mode 100644
index 000000000..3dc8a7b02
--- /dev/null
+++ b/pkg/tcpip/stack/hook_string.go
@@ -0,0 +1,41 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type Hook ."; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[Prerouting-0]
+ _ = x[Input-1]
+ _ = x[Forward-2]
+ _ = x[Output-3]
+ _ = x[Postrouting-4]
+ _ = x[NumHooks-5]
+}
+
+const _Hook_name = "PreroutingInputForwardOutputPostroutingNumHooks"
+
+var _Hook_index = [...]uint8{0, 10, 15, 22, 28, 39, 47}
+
+func (i Hook) String() string {
+ if i >= Hook(len(_Hook_index)-1) {
+ return "Hook(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _Hook_name[_Hook_index[i]:_Hook_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 52890f6eb..7ea87d325 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -175,9 +175,10 @@ func DefaultTables() *IPTables {
},
},
priorities: [NumHooks][]TableID{
- Prerouting: {MangleID, NATID},
- Input: {NATID, FilterID},
- Output: {MangleID, NATID, FilterID},
+ Prerouting: {MangleID, NATID},
+ Input: {NATID, FilterID},
+ Output: {MangleID, NATID, FilterID},
+ Postrouting: {MangleID, NATID},
},
connections: ConnTrack{
seed: generateRandUint32(),
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 0e8b90c9b..317efe754 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -182,3 +182,81 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
return RuleAccept, 0
}
+
+// SNATTarget modifies the source port/IP in the outgoing packets.
+type SNATTarget struct {
+ Addr tcpip.Address
+ Port uint16
+
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+ // Sanity check.
+ if st.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ st.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ // Packet is already manipulated.
+ if pkt.NatDone {
+ return RuleAccept, 0
+ }
+
+ // Drop the packet if network and transport header are not set.
+ if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
+ return RuleDrop, 0
+ }
+
+ switch hook {
+ case Postrouting, Input:
+ case Prerouting, Output, Forward:
+ panic(fmt.Sprintf("%s not supported", hook))
+ default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ switch protocol := pkt.TransportProtocolNumber; protocol {
+ case header.UDPProtocolNumber:
+ udpHeader := header.UDP(pkt.TransportHeader().View())
+ udpHeader.SetChecksum(0)
+ udpHeader.SetSourcePort(st.Port)
+ netHeader := pkt.Network()
+ netHeader.SetSourceAddress(st.Addr)
+
+ // Only calculate the checksum if offloading isn't supported.
+ if r.RequiresTXTransportChecksum() {
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
+ xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
+ udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
+ }
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
+ pkt.NatDone = true
+ case header.TCPProtocolNumber:
+ if ct == nil {
+ return RuleAccept, 0
+ }
+
+ // Set up conection for matching NAT rule. Only the first
+ // packet of the connection comes here. Other packets will be
+ // manipulated in connection tracking.
+ if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
+ ct.handlePacket(pkt, hook, gso, r)
+ }
+ default:
+ return RuleDrop, 0
+ }
+
+ return RuleAccept, 0
+}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 8f288675d..c10304d5f 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -299,9 +299,18 @@ func (pk *PacketBuffer) Network() header.Network {
// See PacketBuffer.Data for details about how a packet buffer holds an inbound
// packet.
func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
- return NewPacketBuffer(PacketBufferOptions{
+ newPk := NewPacketBuffer(PacketBufferOptions{
Data: buffer.NewVectorisedView(pk.Size(), pk.Views()),
})
+ // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
+ // maintain this flag in the packet. Currently conntrack needs this flag to
+ // tell if a noop connection should be inserted at Input hook. Once conntrack
+ // redefines the manipulation field as mutable, we won't need the special noop
+ // connection.
+ if pk.NatDone {
+ newPk.NatDone = true
+ }
+ return newPk
}
// headerInfo stores metadata about a header in a packet.
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 60de16579..2b6e6a89f 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1556,6 +1556,10 @@ type IPStats struct {
// chain.
IPTablesOutputDropped *StatCounter
+ // IPTablesPostroutingDropped is the number of IP packets dropped in the
+ // Postrouting chain.
+ IPTablesPostroutingDropped *StatCounter
+
// TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
// of IPStats.
// OptionTimestampReceived is the number of Timestamp options seen.
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index d6c69a319..04d112134 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -456,3 +456,11 @@ func TestNATPreRECVORIGDSTADDR(t *testing.T) {
func TestNATOutRECVORIGDSTADDR(t *testing.T) {
singleTest(t, &NATOutRECVORIGDSTADDR{})
}
+
+func TestNATPostSNATUDP(t *testing.T) {
+ singleTest(t, &NATPostSNATUDP{})
+}
+
+func TestNATPostSNATTCP(t *testing.T) {
+ singleTest(t, &NATPostSNATTCP{})
+}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index bba17b894..4590e169d 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -69,29 +69,41 @@ func tableRules(ipv6 bool, table string, argsList [][]string) error {
return nil
}
-// listenUDP listens on a UDP port and returns the value of net.Conn.Read() for
-// the first read on that port.
+// listenUDP listens on a UDP port and returns nil if the first read from that
+// port is successful.
func listenUDP(ctx context.Context, port int, ipv6 bool) error {
+ _, err := listenUDPFrom(ctx, port, ipv6)
+ return err
+}
+
+// listenUDPFrom listens on a UDP port and returns the sender's UDP address if
+// the first read from that port is successful.
+func listenUDPFrom(ctx context.Context, port int, ipv6 bool) (*net.UDPAddr, error) {
localAddr := net.UDPAddr{
Port: port,
}
conn, err := net.ListenUDP(udpNetwork(ipv6), &localAddr)
if err != nil {
- return err
+ return nil, err
}
defer conn.Close()
- ch := make(chan error)
+ type result struct {
+ remoteAddr *net.UDPAddr
+ err error
+ }
+
+ ch := make(chan result)
go func() {
- _, err = conn.Read([]byte{0})
- ch <- err
+ _, remoteAddr, err := conn.ReadFromUDP([]byte{0})
+ ch <- result{remoteAddr, err}
}()
select {
- case err := <-ch:
- return err
+ case res := <-ch:
+ return res.remoteAddr, res.err
case <-ctx.Done():
- return ctx.Err()
+ return nil, fmt.Errorf("timed out reading from %s: %w", &localAddr, ctx.Err())
}
}
@@ -125,8 +137,16 @@ func sendUDPLoop(ctx context.Context, ip net.IP, port int, ipv6 bool) error {
}
}
-// listenTCP listens for connections on a TCP port.
+// listenTCP listens for connections on a TCP port, and returns nil if a
+// connection is established.
func listenTCP(ctx context.Context, port int, ipv6 bool) error {
+ _, err := listenTCPFrom(ctx, port, ipv6)
+ return err
+}
+
+// listenTCP listens for connections on a TCP port, and returns the remote
+// TCP address if a connection is established.
+func listenTCPFrom(ctx context.Context, port int, ipv6 bool) (net.Addr, error) {
localAddr := net.TCPAddr{
Port: port,
}
@@ -134,23 +154,32 @@ func listenTCP(ctx context.Context, port int, ipv6 bool) error {
// Starts listening on port.
lConn, err := net.ListenTCP(tcpNetwork(ipv6), &localAddr)
if err != nil {
- return err
+ return nil, err
}
defer lConn.Close()
+ type result struct {
+ remoteAddr net.Addr
+ err error
+ }
+
// Accept connections on port.
- ch := make(chan error)
+ ch := make(chan result)
go func() {
conn, err := lConn.AcceptTCP()
- ch <- err
+ var remoteAddr net.Addr
+ if err == nil {
+ remoteAddr = conn.RemoteAddr()
+ }
+ ch <- result{remoteAddr, err}
conn.Close()
}()
select {
- case err := <-ch:
- return err
+ case res := <-ch:
+ return res.remoteAddr, res.err
case <-ctx.Done():
- return fmt.Errorf("timed out waiting for a connection at %#v: %w", localAddr, ctx.Err())
+ return nil, fmt.Errorf("timed out waiting for a connection at %s: %w", &localAddr, ctx.Err())
}
}
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index 0776639a7..0f25b6a18 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"net"
+ "strconv"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/binary"
@@ -48,6 +49,8 @@ func init() {
RegisterTestCase(&NATOutOriginalDst{})
RegisterTestCase(&NATPreRECVORIGDSTADDR{})
RegisterTestCase(&NATOutRECVORIGDSTADDR{})
+ RegisterTestCase(&NATPostSNATUDP{})
+ RegisterTestCase(&NATPostSNATTCP{})
}
// NATPreRedirectUDPPort tests that packets are redirected to different port.
@@ -486,7 +489,12 @@ func (*NATLoopbackSkipsPrerouting) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (*NATLoopbackSkipsPrerouting) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
// Redirect anything sent to localhost to an unused port.
- dest := []byte{127, 0, 0, 1}
+ var dest net.IP
+ if ipv6 {
+ dest = net.IPv6loopback
+ } else {
+ dest = net.IPv4(127, 0, 0, 1)
+ }
if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil {
return err
}
@@ -915,3 +923,115 @@ func addrMatches6(got unix.RawSockaddrInet6, wantAddrs []net.IP, port uint16) er
}
return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs)
}
+
+const (
+ snatAddrV4 = "194.236.50.155"
+ snatAddrV6 = "2a0a::1"
+ snatPort = 43
+)
+
+// NATPostSNATUDP tests that the source port/IP in the packets are modified as expected.
+type NATPostSNATUDP struct{ localCase }
+
+var _ TestCase = (*NATPostSNATUDP)(nil)
+
+// Name implements TestCase.Name.
+func (*NATPostSNATUDP) Name() string {
+ return "NATPostSNATUDP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (*NATPostSNATUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ var source string
+ if ipv6 {
+ source = fmt.Sprintf("[%s]:%d", snatAddrV6, snatPort)
+ } else {
+ source = fmt.Sprintf("%s:%d", snatAddrV4, snatPort)
+ }
+
+ if err := natTable(ipv6, "-A", "POSTROUTING", "-p", "udp", "-j", "SNAT", "--to-source", source); err != nil {
+ return err
+ }
+ return sendUDPLoop(ctx, ip, acceptPort, ipv6)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (*NATPostSNATUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ remote, err := listenUDPFrom(ctx, acceptPort, ipv6)
+ if err != nil {
+ return err
+ }
+ var snatAddr string
+ if ipv6 {
+ snatAddr = snatAddrV6
+ } else {
+ snatAddr = snatAddrV4
+ }
+ if got, want := remote.IP, net.ParseIP(snatAddr); !got.Equal(want) {
+ return fmt.Errorf("got remote address = %s, want = %s", got, want)
+ }
+ if got, want := remote.Port, snatPort; got != want {
+ return fmt.Errorf("got remote port = %d, want = %d", got, want)
+ }
+ return nil
+}
+
+// NATPostSNATTCP tests that the source port/IP in the packets are modified as
+// expected.
+type NATPostSNATTCP struct{ localCase }
+
+var _ TestCase = (*NATPostSNATTCP)(nil)
+
+// Name implements TestCase.Name.
+func (*NATPostSNATTCP) Name() string {
+ return "NATPostSNATTCP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (*NATPostSNATTCP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ addrs, err := getInterfaceAddrs(ipv6)
+ if err != nil {
+ return err
+ }
+ var source string
+ for _, addr := range addrs {
+ if addr.To4() != nil {
+ if !ipv6 {
+ source = fmt.Sprintf("%s:%d", addr, snatPort)
+ }
+ } else if ipv6 && addr.IsGlobalUnicast() {
+ source = fmt.Sprintf("[%s]:%d", addr, snatPort)
+ }
+ }
+ if source == "" {
+ return fmt.Errorf("can't find any interface address to use")
+ }
+
+ if err := natTable(ipv6, "-A", "POSTROUTING", "-p", "tcp", "-j", "SNAT", "--to-source", source); err != nil {
+ return err
+ }
+ return connectTCP(ctx, ip, acceptPort, ipv6)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (*NATPostSNATTCP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ remote, err := listenTCPFrom(ctx, acceptPort, ipv6)
+ if err != nil {
+ return err
+ }
+ HostStr, portStr, err := net.SplitHostPort(remote.String())
+ if err != nil {
+ return err
+ }
+ if got, want := HostStr, ip.String(); got != want {
+ return fmt.Errorf("got remote address = %s, want = %s", got, want)
+ }
+ port, err := strconv.ParseInt(portStr, 10, 0)
+ if err != nil {
+ return err
+ }
+ if got, want := int(port), snatPort; got != want {
+ return fmt.Errorf("got remote port = %d, want = %d", got, want)
+ }
+ return nil
+}