summaryrefslogtreecommitdiffhomepage
path: root/test/iptables
diff options
context:
space:
mode:
Diffstat (limited to 'test/iptables')
-rw-r--r--test/iptables/iptables_test.go8
-rw-r--r--test/iptables/iptables_util.go61
-rw-r--r--test/iptables/nat.go122
3 files changed, 174 insertions, 17 deletions
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index d6c69a319..04d112134 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -456,3 +456,11 @@ func TestNATPreRECVORIGDSTADDR(t *testing.T) {
func TestNATOutRECVORIGDSTADDR(t *testing.T) {
singleTest(t, &NATOutRECVORIGDSTADDR{})
}
+
+func TestNATPostSNATUDP(t *testing.T) {
+ singleTest(t, &NATPostSNATUDP{})
+}
+
+func TestNATPostSNATTCP(t *testing.T) {
+ singleTest(t, &NATPostSNATTCP{})
+}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index bba17b894..4590e169d 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -69,29 +69,41 @@ func tableRules(ipv6 bool, table string, argsList [][]string) error {
return nil
}
-// listenUDP listens on a UDP port and returns the value of net.Conn.Read() for
-// the first read on that port.
+// listenUDP listens on a UDP port and returns nil if the first read from that
+// port is successful.
func listenUDP(ctx context.Context, port int, ipv6 bool) error {
+ _, err := listenUDPFrom(ctx, port, ipv6)
+ return err
+}
+
+// listenUDPFrom listens on a UDP port and returns the sender's UDP address if
+// the first read from that port is successful.
+func listenUDPFrom(ctx context.Context, port int, ipv6 bool) (*net.UDPAddr, error) {
localAddr := net.UDPAddr{
Port: port,
}
conn, err := net.ListenUDP(udpNetwork(ipv6), &localAddr)
if err != nil {
- return err
+ return nil, err
}
defer conn.Close()
- ch := make(chan error)
+ type result struct {
+ remoteAddr *net.UDPAddr
+ err error
+ }
+
+ ch := make(chan result)
go func() {
- _, err = conn.Read([]byte{0})
- ch <- err
+ _, remoteAddr, err := conn.ReadFromUDP([]byte{0})
+ ch <- result{remoteAddr, err}
}()
select {
- case err := <-ch:
- return err
+ case res := <-ch:
+ return res.remoteAddr, res.err
case <-ctx.Done():
- return ctx.Err()
+ return nil, fmt.Errorf("timed out reading from %s: %w", &localAddr, ctx.Err())
}
}
@@ -125,8 +137,16 @@ func sendUDPLoop(ctx context.Context, ip net.IP, port int, ipv6 bool) error {
}
}
-// listenTCP listens for connections on a TCP port.
+// listenTCP listens for connections on a TCP port, and returns nil if a
+// connection is established.
func listenTCP(ctx context.Context, port int, ipv6 bool) error {
+ _, err := listenTCPFrom(ctx, port, ipv6)
+ return err
+}
+
+// listenTCP listens for connections on a TCP port, and returns the remote
+// TCP address if a connection is established.
+func listenTCPFrom(ctx context.Context, port int, ipv6 bool) (net.Addr, error) {
localAddr := net.TCPAddr{
Port: port,
}
@@ -134,23 +154,32 @@ func listenTCP(ctx context.Context, port int, ipv6 bool) error {
// Starts listening on port.
lConn, err := net.ListenTCP(tcpNetwork(ipv6), &localAddr)
if err != nil {
- return err
+ return nil, err
}
defer lConn.Close()
+ type result struct {
+ remoteAddr net.Addr
+ err error
+ }
+
// Accept connections on port.
- ch := make(chan error)
+ ch := make(chan result)
go func() {
conn, err := lConn.AcceptTCP()
- ch <- err
+ var remoteAddr net.Addr
+ if err == nil {
+ remoteAddr = conn.RemoteAddr()
+ }
+ ch <- result{remoteAddr, err}
conn.Close()
}()
select {
- case err := <-ch:
- return err
+ case res := <-ch:
+ return res.remoteAddr, res.err
case <-ctx.Done():
- return fmt.Errorf("timed out waiting for a connection at %#v: %w", localAddr, ctx.Err())
+ return nil, fmt.Errorf("timed out waiting for a connection at %s: %w", &localAddr, ctx.Err())
}
}
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index 0776639a7..0f25b6a18 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"net"
+ "strconv"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/binary"
@@ -48,6 +49,8 @@ func init() {
RegisterTestCase(&NATOutOriginalDst{})
RegisterTestCase(&NATPreRECVORIGDSTADDR{})
RegisterTestCase(&NATOutRECVORIGDSTADDR{})
+ RegisterTestCase(&NATPostSNATUDP{})
+ RegisterTestCase(&NATPostSNATTCP{})
}
// NATPreRedirectUDPPort tests that packets are redirected to different port.
@@ -486,7 +489,12 @@ func (*NATLoopbackSkipsPrerouting) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (*NATLoopbackSkipsPrerouting) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
// Redirect anything sent to localhost to an unused port.
- dest := []byte{127, 0, 0, 1}
+ var dest net.IP
+ if ipv6 {
+ dest = net.IPv6loopback
+ } else {
+ dest = net.IPv4(127, 0, 0, 1)
+ }
if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil {
return err
}
@@ -915,3 +923,115 @@ func addrMatches6(got unix.RawSockaddrInet6, wantAddrs []net.IP, port uint16) er
}
return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs)
}
+
+const (
+ snatAddrV4 = "194.236.50.155"
+ snatAddrV6 = "2a0a::1"
+ snatPort = 43
+)
+
+// NATPostSNATUDP tests that the source port/IP in the packets are modified as expected.
+type NATPostSNATUDP struct{ localCase }
+
+var _ TestCase = (*NATPostSNATUDP)(nil)
+
+// Name implements TestCase.Name.
+func (*NATPostSNATUDP) Name() string {
+ return "NATPostSNATUDP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (*NATPostSNATUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ var source string
+ if ipv6 {
+ source = fmt.Sprintf("[%s]:%d", snatAddrV6, snatPort)
+ } else {
+ source = fmt.Sprintf("%s:%d", snatAddrV4, snatPort)
+ }
+
+ if err := natTable(ipv6, "-A", "POSTROUTING", "-p", "udp", "-j", "SNAT", "--to-source", source); err != nil {
+ return err
+ }
+ return sendUDPLoop(ctx, ip, acceptPort, ipv6)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (*NATPostSNATUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ remote, err := listenUDPFrom(ctx, acceptPort, ipv6)
+ if err != nil {
+ return err
+ }
+ var snatAddr string
+ if ipv6 {
+ snatAddr = snatAddrV6
+ } else {
+ snatAddr = snatAddrV4
+ }
+ if got, want := remote.IP, net.ParseIP(snatAddr); !got.Equal(want) {
+ return fmt.Errorf("got remote address = %s, want = %s", got, want)
+ }
+ if got, want := remote.Port, snatPort; got != want {
+ return fmt.Errorf("got remote port = %d, want = %d", got, want)
+ }
+ return nil
+}
+
+// NATPostSNATTCP tests that the source port/IP in the packets are modified as
+// expected.
+type NATPostSNATTCP struct{ localCase }
+
+var _ TestCase = (*NATPostSNATTCP)(nil)
+
+// Name implements TestCase.Name.
+func (*NATPostSNATTCP) Name() string {
+ return "NATPostSNATTCP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (*NATPostSNATTCP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ addrs, err := getInterfaceAddrs(ipv6)
+ if err != nil {
+ return err
+ }
+ var source string
+ for _, addr := range addrs {
+ if addr.To4() != nil {
+ if !ipv6 {
+ source = fmt.Sprintf("%s:%d", addr, snatPort)
+ }
+ } else if ipv6 && addr.IsGlobalUnicast() {
+ source = fmt.Sprintf("[%s]:%d", addr, snatPort)
+ }
+ }
+ if source == "" {
+ return fmt.Errorf("can't find any interface address to use")
+ }
+
+ if err := natTable(ipv6, "-A", "POSTROUTING", "-p", "tcp", "-j", "SNAT", "--to-source", source); err != nil {
+ return err
+ }
+ return connectTCP(ctx, ip, acceptPort, ipv6)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (*NATPostSNATTCP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ remote, err := listenTCPFrom(ctx, acceptPort, ipv6)
+ if err != nil {
+ return err
+ }
+ HostStr, portStr, err := net.SplitHostPort(remote.String())
+ if err != nil {
+ return err
+ }
+ if got, want := HostStr, ip.String(); got != want {
+ return fmt.Errorf("got remote address = %s, want = %s", got, want)
+ }
+ port, err := strconv.ParseInt(portStr, 10, 0)
+ if err != nil {
+ return err
+ }
+ if got, want := int(port), snatPort; got != want {
+ return fmt.Errorf("got remote port = %d, want = %d", got, want)
+ }
+ return nil
+}