summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-10-12 19:36:55 -0700
committergVisor bot <gvisor-bot@google.com>2021-10-12 19:39:10 -0700
commit747cb92460bc30983263fcd85562a8586842d824 (patch)
tree5087778a134e917381947dbc6ac4c441a543ece1
parent049fa8ea9999799cc304fe811ca8028a195be493 (diff)
Support Twice NAT
This CL allows both SNAT and DNAT targets to be performed on the same packet. Fixes #5696. PiperOrigin-RevId: 402714738
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go2
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go2
-rw-r--r--pkg/tcpip/stack/conntrack.go24
-rw-r--r--pkg/tcpip/stack/iptables.go53
-rw-r--r--pkg/tcpip/stack/iptables_targets.go5
-rw-r--r--pkg/tcpip/stack/packet_buffer.go13
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go160
7 files changed, 204 insertions, 55 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 9b71738ae..6e52cc9bb 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -439,7 +439,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// We should do this for every packet, rather than only NATted packets, but
// removing this check short circuits broadcasts before they are sent out to
// other hosts.
- if pkt.NatDone {
+ if pkt.DNATDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
// Since we rewrote the packet but it is being routed back to us, we
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 600e805f8..0406a2e6e 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -761,7 +761,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// We should do this for every packet, rather than only NATted packets, but
// removing this check short circuits broadcasts before they are sent out to
// other hosts.
- if pkt.NatDone {
+ if pkt.DNATDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
// Since we rewrote the packet but it is being routed back to us, we
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 48f290187..c9a8e72a3 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -409,18 +409,19 @@ func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool)
}
}
-func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) {
- if pkt.NatDone {
- return
- }
-
+// handlePacket attempts to handle a packet and perform NAT if the connection
+// has had NAT performed on it.
+//
+// Returns true if the packet can skip the NAT table.
+func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
transportHeader, ok := getTransportHeader(pkt)
if !ok {
- return
+ return false
}
fullChecksum := false
updatePseudoHeader := false
+ natDone := &pkt.SNATDone
dnat := false
switch hook {
case Prerouting:
@@ -429,11 +430,13 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) {
fullChecksum = true
updatePseudoHeader = true
+ natDone = &pkt.DNATDone
dnat = true
case Input:
case Forward:
panic("should not handle packet in the forwarding hook")
case Output:
+ natDone = &pkt.DNATDone
dnat = true
fallthrough
case Postrouting:
@@ -447,6 +450,10 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) {
panic(fmt.Sprintf("unrecognized hook = %d", hook))
}
+ if *natDone {
+ panic(fmt.Sprintf("packet already had NAT(dnat=%t) performed at hook=%s; pkt=%#v", dnat, hook, pkt))
+ }
+
// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
// validated if checksum offloading is off. It may require IP defrag if the
// packets are fragmented.
@@ -490,7 +497,7 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) {
return tuple.id(), true
}()
if !performManip {
- return
+ return false
}
newPort := tid.dstPort
@@ -510,7 +517,8 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) {
newAddr,
)
- pkt.NatDone = true
+ *natDone = true
+ return true
}
// bucket gets the conntrack bucket for a tupleID.
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 5808be685..0baa378ea 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -277,10 +277,7 @@ func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndp
return true
}
- if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil {
- pkt.tuple = t
- t.conn.handlePacket(pkt, hook, nil /* route */)
- }
+ pkt.tuple = it.connections.getConnOrMaybeInsertNoop(pkt)
return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */)
}
@@ -298,10 +295,6 @@ func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
return true
}
- if t := pkt.tuple; t != nil {
- t.conn.handlePacket(pkt, hook, nil /* route */)
- }
-
ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */)
if t := pkt.tuple; t != nil {
t.conn.finalize()
@@ -336,10 +329,7 @@ func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string)
return true
}
- if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil {
- pkt.tuple = t
- t.conn.handlePacket(pkt, hook, r)
- }
+ pkt.tuple = it.connections.getConnOrMaybeInsertNoop(pkt)
return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
}
@@ -357,10 +347,6 @@ func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP Addr
return true
}
- if t := pkt.tuple; t != nil {
- t.conn.handlePacket(pkt, hook, r)
- }
-
ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName)
if t := pkt.tuple; t != nil {
t.conn.finalize()
@@ -396,9 +382,7 @@ func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP Addr
// Go through each table containing the hook.
priorities := it.priorities[hook]
for _, tableID := range priorities {
- // If handlePacket already NATed the packet, we don't need to
- // check the NAT table.
- if tableID == NATID && pkt.NatDone {
+ if t := pkt.tuple; t != nil && tableID == NATID && t.conn.handlePacket(pkt, hook, r) {
continue
}
var table Table
@@ -476,7 +460,7 @@ func (it *IPTables) startReaper(interval time.Duration) {
func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
return checkPackets(pkts, func(pkt *PacketBuffer) bool {
return it.CheckOutput(pkt, r, outNicName)
- })
+ }, true /* dnat */)
}
// CheckPostroutingPackets performs the postrouting hook on the packets.
@@ -487,24 +471,27 @@ func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicNa
func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
return checkPackets(pkts, func(pkt *PacketBuffer) bool {
return it.CheckPostrouting(pkt, r, addressEP, outNicName)
- })
+ }, false /* dnat */)
}
-func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool, dnat bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if !pkt.NatDone {
- if ok := f(pkt); !ok {
- if drop == nil {
- drop = make(map[*PacketBuffer]struct{})
- }
- drop[pkt] = struct{}{}
+ natDone := &pkt.SNATDone
+ if dnat {
+ natDone = &pkt.DNATDone
+ }
+
+ if ok := f(pkt); !ok {
+ if drop == nil {
+ drop = make(map[*PacketBuffer]struct{})
}
- if pkt.NatDone {
- if natPkts == nil {
- natPkts = make(map[*PacketBuffer]struct{})
- }
- natPkts[pkt] = struct{}{}
+ drop[pkt] = struct{}{}
+ }
+ if *natDone {
+ if natPkts == nil {
+ natPkts = make(map[*PacketBuffer]struct{})
}
+ natPkts[pkt] = struct{}{}
}
}
return drop, natPkts
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 85490e2d4..ef515bdd2 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -175,11 +175,6 @@ type SNATTarget struct {
}
func natAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) (RuleVerdict, int) {
- // Packet is already manipulated.
- if pkt.NatDone {
- return RuleAccept, 0
- }
-
// Drop the packet if network and transport header are not set.
if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
return RuleDrop, 0
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 888a8bd9d..c4a4bbd22 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -126,9 +126,13 @@ type PacketBuffer struct {
EgressRoute RouteInfo
GSOOptions GSO
- // NatDone indicates if the packet has been manipulated as per NAT
- // iptables rule.
- NatDone bool
+ // SNATDone indicates if the packet's source has been manipulated as per
+ // iptables NAT table.
+ SNATDone bool
+
+ // DNATDone indicates if the packet's destination has been manipulated as per
+ // iptables NAT table.
+ DNATDone bool
// PktType indicates the SockAddrLink.PacketType of the packet as defined in
// https://www.man7.org/linux/man-pages/man7/packet.7.html.
@@ -298,7 +302,8 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
Owner: pk.Owner,
GSOOptions: pk.GSOOptions,
NetworkProtocolNumber: pk.NetworkProtocolNumber,
- NatDone: pk.NatDone,
+ DNATDone: pk.DNATDone,
+ SNATDone: pk.SNATDone,
TransportProtocolNumber: pk.TransportProtocolNumber,
PktType: pk.PktType,
NICID: pk.NICID,
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 7f872c271..957a779bf 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -1285,19 +1285,109 @@ func TestNAT(t *testing.T) {
},
}
+ setupTwiceNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, dnatAddr tcpip.Address, snatTarget stack.Target) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ // Prerouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transProto,
+ CheckProtocol: true,
+ InputInterface: utils.RouterNIC2Name,
+ },
+ Target: &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort},
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Input
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Forward
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Output
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Postrouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transProto,
+ CheckProtocol: true,
+ OutputInterface: utils.RouterNIC1Name,
+ },
+ Target: snatTarget,
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 2,
+ stack.Forward: 3,
+ stack.Output: 4,
+ stack.Postrouting: 5,
+ },
+ }
+
+ if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
+ }
+ twiceNATTypes := []natType{
+ {
+ name: "DNAT-Masquerade",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) {
+ t.Helper()
+
+ setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.MasqueradeTarget{NetworkProtocol: netProto})
+ },
+ },
+ {
+ name: "DNAT-SNAT",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) {
+ t.Helper()
+
+ setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr})
+ },
+ },
+ }
+
tests := []struct {
name string
netProto tcpip.NetworkProtocolNumber
// Setups up the stacks in such a way that:
//
// - Host2 is the client for all tests.
- // - Host1 is the server when performing SNAT
+ // - When performing SNAT only:
+ // + Host1 is the server.
// + NAT will transform client-originating packets' source addresses to
// the router's NIC1's address before reaching Host1.
- // - Router is the server when performing DNAT (client will still attempt to
- // send packets to Host1).
+ // - When performing DNAT only:
+ // + Router is the server.
+ // + Client will send packets directed to Host1.
// + NAT will transform client-originating packets' destination addresses
// to the router's NIC2's address.
+ // - When performing Twice-NAT:
+ // + Host1 is the server.
+ // + Client will send packets directed to router's NIC2.
+ // + NAT will transform client originating packets' destination addresses
+ // to Host1's address.
+ // + NAT will transform client-originating packets' source addresses to
+ // the router's NIC1's address before reaching Host1.
epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
natTypes []natType
}{
@@ -1370,6 +1460,38 @@ func TestNAT(t *testing.T) {
natTypes: dnatTypes,
},
{
+ name: "IPv4 Twice-NAT",
+ netProto: ipv4.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ listenerStack := host1Stack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address
+ clientConnectPort := serverAddr.Port
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: twiceNATTypes,
+ },
+ {
name: "IPv6 SNAT",
netProto: ipv6.ProtocolNumber,
epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
@@ -1437,6 +1559,38 @@ func TestNAT(t *testing.T) {
},
natTypes: dnatTypes,
},
+ {
+ name: "IPv6 Twice-NAT",
+ netProto: ipv6.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ listenerStack := host1Stack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address
+ clientConnectPort := serverAddr.Port
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: twiceNATTypes,
+ },
}
subTests := []struct {