diff options
Diffstat (limited to 'test/iptables')
-rw-r--r-- | test/iptables/iptables_test.go | 8 | ||||
-rw-r--r-- | test/iptables/iptables_util.go | 61 | ||||
-rw-r--r-- | test/iptables/nat.go | 122 |
3 files changed, 174 insertions, 17 deletions
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index d6c69a319..04d112134 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -456,3 +456,11 @@ func TestNATPreRECVORIGDSTADDR(t *testing.T) { func TestNATOutRECVORIGDSTADDR(t *testing.T) { singleTest(t, &NATOutRECVORIGDSTADDR{}) } + +func TestNATPostSNATUDP(t *testing.T) { + singleTest(t, &NATPostSNATUDP{}) +} + +func TestNATPostSNATTCP(t *testing.T) { + singleTest(t, &NATPostSNATTCP{}) +} diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index bba17b894..4590e169d 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -69,29 +69,41 @@ func tableRules(ipv6 bool, table string, argsList [][]string) error { return nil } -// listenUDP listens on a UDP port and returns the value of net.Conn.Read() for -// the first read on that port. +// listenUDP listens on a UDP port and returns nil if the first read from that +// port is successful. func listenUDP(ctx context.Context, port int, ipv6 bool) error { + _, err := listenUDPFrom(ctx, port, ipv6) + return err +} + +// listenUDPFrom listens on a UDP port and returns the sender's UDP address if +// the first read from that port is successful. +func listenUDPFrom(ctx context.Context, port int, ipv6 bool) (*net.UDPAddr, error) { localAddr := net.UDPAddr{ Port: port, } conn, err := net.ListenUDP(udpNetwork(ipv6), &localAddr) if err != nil { - return err + return nil, err } defer conn.Close() - ch := make(chan error) + type result struct { + remoteAddr *net.UDPAddr + err error + } + + ch := make(chan result) go func() { - _, err = conn.Read([]byte{0}) - ch <- err + _, remoteAddr, err := conn.ReadFromUDP([]byte{0}) + ch <- result{remoteAddr, err} }() select { - case err := <-ch: - return err + case res := <-ch: + return res.remoteAddr, res.err case <-ctx.Done(): - return ctx.Err() + return nil, fmt.Errorf("timed out reading from %s: %w", &localAddr, ctx.Err()) } } @@ -125,8 +137,16 @@ func sendUDPLoop(ctx context.Context, ip net.IP, port int, ipv6 bool) error { } } -// listenTCP listens for connections on a TCP port. +// listenTCP listens for connections on a TCP port, and returns nil if a +// connection is established. func listenTCP(ctx context.Context, port int, ipv6 bool) error { + _, err := listenTCPFrom(ctx, port, ipv6) + return err +} + +// listenTCP listens for connections on a TCP port, and returns the remote +// TCP address if a connection is established. +func listenTCPFrom(ctx context.Context, port int, ipv6 bool) (net.Addr, error) { localAddr := net.TCPAddr{ Port: port, } @@ -134,23 +154,32 @@ func listenTCP(ctx context.Context, port int, ipv6 bool) error { // Starts listening on port. lConn, err := net.ListenTCP(tcpNetwork(ipv6), &localAddr) if err != nil { - return err + return nil, err } defer lConn.Close() + type result struct { + remoteAddr net.Addr + err error + } + // Accept connections on port. - ch := make(chan error) + ch := make(chan result) go func() { conn, err := lConn.AcceptTCP() - ch <- err + var remoteAddr net.Addr + if err == nil { + remoteAddr = conn.RemoteAddr() + } + ch <- result{remoteAddr, err} conn.Close() }() select { - case err := <-ch: - return err + case res := <-ch: + return res.remoteAddr, res.err case <-ctx.Done(): - return fmt.Errorf("timed out waiting for a connection at %#v: %w", localAddr, ctx.Err()) + return nil, fmt.Errorf("timed out waiting for a connection at %s: %w", &localAddr, ctx.Err()) } } diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 0776639a7..0f25b6a18 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "net" + "strconv" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/binary" @@ -48,6 +49,8 @@ func init() { RegisterTestCase(&NATOutOriginalDst{}) RegisterTestCase(&NATPreRECVORIGDSTADDR{}) RegisterTestCase(&NATOutRECVORIGDSTADDR{}) + RegisterTestCase(&NATPostSNATUDP{}) + RegisterTestCase(&NATPostSNATTCP{}) } // NATPreRedirectUDPPort tests that packets are redirected to different port. @@ -486,7 +489,12 @@ func (*NATLoopbackSkipsPrerouting) Name() string { // ContainerAction implements TestCase.ContainerAction. func (*NATLoopbackSkipsPrerouting) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect anything sent to localhost to an unused port. - dest := []byte{127, 0, 0, 1} + var dest net.IP + if ipv6 { + dest = net.IPv6loopback + } else { + dest = net.IPv4(127, 0, 0, 1) + } if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { return err } @@ -915,3 +923,115 @@ func addrMatches6(got unix.RawSockaddrInet6, wantAddrs []net.IP, port uint16) er } return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs) } + +const ( + snatAddrV4 = "194.236.50.155" + snatAddrV6 = "2a0a::1" + snatPort = 43 +) + +// NATPostSNATUDP tests that the source port/IP in the packets are modified as expected. +type NATPostSNATUDP struct{ localCase } + +var _ TestCase = (*NATPostSNATUDP)(nil) + +// Name implements TestCase.Name. +func (*NATPostSNATUDP) Name() string { + return "NATPostSNATUDP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (*NATPostSNATUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + var source string + if ipv6 { + source = fmt.Sprintf("[%s]:%d", snatAddrV6, snatPort) + } else { + source = fmt.Sprintf("%s:%d", snatAddrV4, snatPort) + } + + if err := natTable(ipv6, "-A", "POSTROUTING", "-p", "udp", "-j", "SNAT", "--to-source", source); err != nil { + return err + } + return sendUDPLoop(ctx, ip, acceptPort, ipv6) +} + +// LocalAction implements TestCase.LocalAction. +func (*NATPostSNATUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + remote, err := listenUDPFrom(ctx, acceptPort, ipv6) + if err != nil { + return err + } + var snatAddr string + if ipv6 { + snatAddr = snatAddrV6 + } else { + snatAddr = snatAddrV4 + } + if got, want := remote.IP, net.ParseIP(snatAddr); !got.Equal(want) { + return fmt.Errorf("got remote address = %s, want = %s", got, want) + } + if got, want := remote.Port, snatPort; got != want { + return fmt.Errorf("got remote port = %d, want = %d", got, want) + } + return nil +} + +// NATPostSNATTCP tests that the source port/IP in the packets are modified as +// expected. +type NATPostSNATTCP struct{ localCase } + +var _ TestCase = (*NATPostSNATTCP)(nil) + +// Name implements TestCase.Name. +func (*NATPostSNATTCP) Name() string { + return "NATPostSNATTCP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (*NATPostSNATTCP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + addrs, err := getInterfaceAddrs(ipv6) + if err != nil { + return err + } + var source string + for _, addr := range addrs { + if addr.To4() != nil { + if !ipv6 { + source = fmt.Sprintf("%s:%d", addr, snatPort) + } + } else if ipv6 && addr.IsGlobalUnicast() { + source = fmt.Sprintf("[%s]:%d", addr, snatPort) + } + } + if source == "" { + return fmt.Errorf("can't find any interface address to use") + } + + if err := natTable(ipv6, "-A", "POSTROUTING", "-p", "tcp", "-j", "SNAT", "--to-source", source); err != nil { + return err + } + return connectTCP(ctx, ip, acceptPort, ipv6) +} + +// LocalAction implements TestCase.LocalAction. +func (*NATPostSNATTCP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + remote, err := listenTCPFrom(ctx, acceptPort, ipv6) + if err != nil { + return err + } + HostStr, portStr, err := net.SplitHostPort(remote.String()) + if err != nil { + return err + } + if got, want := HostStr, ip.String(); got != want { + return fmt.Errorf("got remote address = %s, want = %s", got, want) + } + port, err := strconv.ParseInt(portStr, 10, 0) + if err != nil { + return err + } + if got, want := int(port), snatPort; got != want { + return fmt.Errorf("got remote port = %d, want = %d", got, want) + } + return nil +} |