summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/tests/integration/iptables_test.go
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-09-28 14:04:38 -0700
committergVisor bot <gvisor-bot@google.com>2021-09-28 14:07:35 -0700
commited083bac408cfeb68ae401b8cec7cc1be5722562 (patch)
tree5678c9e739df9906d2c187de5291ec50f000815f /pkg/tcpip/tests/integration/iptables_test.go
parente251f6cc5c2641e000846b35e4aa7f0d41c4b319 (diff)
Support naive Masquerade NAT target
* Does not accept a port range (Issue #5772). * Does not support checking for tuple conflits (Issue #5773). PiperOrigin-RevId: 399524088
Diffstat (limited to 'pkg/tcpip/tests/integration/iptables_test.go')
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go274
1 files changed, 150 insertions, 124 deletions
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index bdf4a64b9..f01e2b128 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -1197,24 +1197,15 @@ func TestSNAT(t *testing.T) {
tests := []struct {
name string
+ netProto tcpip.NetworkProtocolNumber
epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
}{
{
- name: "IPv4 host1 server with host2 client",
+ name: "IPv4 host1 server with host2 client",
+ netProto: ipv4.ProtocolNumber,
epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
t.Helper()
- ipt := routerStack.IPTables()
- filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Postrouting]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
- filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv4.ProtocolNumber, Addr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address}
- // Make sure the packet is not dropped by the next rule.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, false, err)
- }
-
ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber)
ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
return endpointAndAddresses{
@@ -1231,21 +1222,11 @@ func TestSNAT(t *testing.T) {
},
},
{
- name: "IPv6 host1 server with host2 client",
+ name: "IPv6 host1 server with host2 client",
+ netProto: ipv6.ProtocolNumber,
epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
t.Helper()
- ipt := routerStack.IPTables()
- filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Postrouting]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
- filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv6.ProtocolNumber, Addr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address}
- // Make sure the packet is not dropped by the next rule.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, true, err)
- }
-
ep1, ep1WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
return endpointAndAddresses{
@@ -1324,120 +1305,165 @@ func TestSNAT(t *testing.T) {
},
}
+ setupNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, target stack.Target) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.NATID, ipv6)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
+ filter.Rules[ruleIdx].Target = target
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
+ }
+
+ natTypes := []struct {
+ name string
+ setupNAT func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber, tcpip.Address)
+ }{
+ {
+ name: "SNAT",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, natToAddr tcpip.Address) {
+ t.Helper()
+
+ setupNAT(t, s, netProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: natToAddr})
+ },
+ },
+ {
+ name: "Masquerade",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, natToAddr tcpip.Address) {
+ t.Helper()
+
+ setupNAT(t, s, netProto, &stack.MasqueradeTarget{NetworkProtocol: netProto})
+ },
+ },
+ }
+
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
for _, subTest := range subTests {
t.Run(subTest.name, func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- }
+ for _, natType := range natTypes {
+ t.Run(natType.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ }
- host1Stack := stack.New(stackOpts)
- routerStack := stack.New(stackOpts)
- host2Stack := stack.New(stackOpts)
- utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+ utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
- epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
- serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
- if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
- t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
- }
- clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
- if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
- t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
- }
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
- if subTest.setupServer != nil {
- subTest.setupServer(t, epsAndAddrs.serverEP)
- }
- {
- err := epsAndAddrs.clientEP.Connect(serverAddr)
- if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
- t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
- }
- }
- nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr}
- if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
- t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
- } else {
- nattedClientAddr.Port = addr.Port
- }
+ natType.setupNAT(t, routerStack, test.netProto, epsAndAddrs.nattedClientAddr)
- serverEP := epsAndAddrs.serverEP
- serverCH := epsAndAddrs.serverReadableCH
- if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil {
- defer ep.Close()
- serverEP = ep
- serverCH = ch
- }
+ serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
+ if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
- write := func(ep tcpip.Endpoint, data []byte) {
- t.Helper()
-
- var r bytes.Reader
- r.Reset(data)
- var wOpts tcpip.WriteOptions
- n, err := ep.Write(&r, wOpts)
- if err != nil {
- t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
- }
- if want := int64(len(data)); n != want {
- t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
- }
- }
+ if subTest.setupServer != nil {
+ subTest.setupServer(t, epsAndAddrs.serverEP)
+ }
+ {
+ err := epsAndAddrs.clientEP.Connect(serverAddr)
+ if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
+ t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
+ }
+ }
+ nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr}
+ if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
+ } else {
+ nattedClientAddr.Port = addr.Port
+ }
- read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
- t.Helper()
-
- var buf bytes.Buffer
- var res tcpip.ReadResult
- for {
- var err tcpip.Error
- opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
- res, err = ep.Read(&buf, opts)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- <-ch
- continue
+ serverEP := epsAndAddrs.serverEP
+ serverCH := epsAndAddrs.serverReadableCH
+ if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil {
+ defer ep.Close()
+ serverEP = ep
+ serverCH = ch
}
- if err != nil {
- t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+
+ write := func(ep tcpip.Endpoint, data []byte) {
+ t.Helper()
+
+ var r bytes.Reader
+ r.Reset(data)
+ var wOpts tcpip.WriteOptions
+ n, err := ep.Write(&r, wOpts)
+ if err != nil {
+ t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
+ }
}
- break
- }
-
- readResult := tcpip.ReadResult{
- Count: len(data),
- Total: len(data),
- }
- if subTest.needRemoteAddr {
- readResult.RemoteAddr = expectedFrom
- }
- if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- )); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
- }
-
- if t.Failed() {
- t.FailNow()
- }
- }
- {
- data := []byte{1, 2, 3, 4}
- write(epsAndAddrs.clientEP, data)
- read(serverCH, serverEP, data, nattedClientAddr)
- }
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
+ t.Helper()
+
+ var buf bytes.Buffer
+ var res tcpip.ReadResult
+ for {
+ var err tcpip.Error
+ opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
+ res, err = ep.Read(&buf, opts)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ }
+ break
+ }
+
+ readResult := tcpip.ReadResult{
+ Count: len(data),
+ Total: len(data),
+ }
+ if subTest.needRemoteAddr {
+ readResult.RemoteAddr = expectedFrom
+ }
+ if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
- {
- data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
- write(serverEP, data)
- read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
+ {
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data)
+ read(serverCH, serverEP, data, nattedClientAddr)
+ }
+
+ {
+ data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
+ write(serverEP, data)
+ read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
+ }
+ })
}
})
}