summaryrefslogtreecommitdiffhomepage
path: root/test/iptables
diff options
context:
space:
mode:
Diffstat (limited to 'test/iptables')
-rw-r--r--test/iptables/BUILD2
-rw-r--r--test/iptables/iptables_test.go8
-rw-r--r--test/iptables/nat.go265
3 files changed, 249 insertions, 26 deletions
diff --git a/test/iptables/BUILD b/test/iptables/BUILD
index 66453772a..ae4bba847 100644
--- a/test/iptables/BUILD
+++ b/test/iptables/BUILD
@@ -15,7 +15,9 @@ go_library(
],
visibility = ["//test/iptables:__subpackages__"],
deps = [
+ "//pkg/binary",
"//pkg/test/testutil",
+ "//pkg/usermem",
],
)
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index 4733146c0..9a4f60a9a 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -424,3 +424,11 @@ func TestNATPreOriginalDst(t *testing.T) {
func TestNATOutOriginalDst(t *testing.T) {
singleTest(t, NATOutOriginalDst{})
}
+
+func TestNATPreRECVORIGDSTADDR(t *testing.T) {
+ singleTest(t, NATPreRECVORIGDSTADDR{})
+}
+
+func TestNATOutRECVORIGDSTADDR(t *testing.T) {
+ singleTest(t, NATOutRECVORIGDSTADDR{})
+}
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index 495241482..c3874240f 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -20,6 +20,9 @@ import (
"fmt"
"net"
"syscall"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const redirectPort = 42
@@ -43,6 +46,8 @@ func init() {
RegisterTestCase(NATLoopbackSkipsPrerouting{})
RegisterTestCase(NATPreOriginalDst{})
RegisterTestCase(NATOutOriginalDst{})
+ RegisterTestCase(NATPreRECVORIGDSTADDR{})
+ RegisterTestCase(NATOutRECVORIGDSTADDR{})
}
// NATPreRedirectUDPPort tests that packets are redirected to different port.
@@ -538,9 +543,9 @@ func (NATOutOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool)
}
func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.IP) error {
- // The net package doesn't give guarantee access to the connection's
+ // The net package doesn't give guaranteed access to the connection's
// underlying FD, and thus we cannot call getsockopt. We have to use
- // traditional syscalls for SO_ORIGINAL_DST.
+ // traditional syscalls.
// Create the listening socket, bind, listen, and accept.
family := syscall.AF_INET
@@ -609,36 +614,14 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.
if err != nil {
return err
}
- // The original destination could be any of our IPs.
- for _, dst := range originalDsts {
- want := syscall.RawSockaddrInet6{
- Family: syscall.AF_INET6,
- Port: htons(dropPort),
- }
- copy(want.Addr[:], dst.To16())
- if got == want {
- return nil
- }
- }
- return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts)
+ return addrMatches6(got, originalDsts, dropPort)
}
got, err := originalDestination4(connFD)
if err != nil {
return err
}
- // The original destination could be any of our IPs.
- for _, dst := range originalDsts {
- want := syscall.RawSockaddrInet4{
- Family: syscall.AF_INET,
- Port: htons(dropPort),
- }
- copy(want.Addr[:], dst.To4())
- if got == want {
- return nil
- }
- }
- return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts)
+ return addrMatches4(got, originalDsts, dropPort)
}
// loopbackTests runs an iptables rule and ensures that packets sent to
@@ -662,3 +645,233 @@ func loopbackTest(ctx context.Context, ipv6 bool, dest net.IP, args ...string) e
return err
}
}
+
+// NATPreRECVORIGDSTADDR tests that IP{V6}_RECVORIGDSTADDR gets the post-NAT
+// address on the PREROUTING chain.
+type NATPreRECVORIGDSTADDR struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (NATPreRECVORIGDSTADDR) Name() string {
+ return "NATPreRECVORIGDSTADDR"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRECVORIGDSTADDR) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ return err
+ }
+
+ if err := recvWithRECVORIGDSTADDR(ctx, ipv6, nil, redirectPort); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRECVORIGDSTADDR) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// NATOutRECVORIGDSTADDR tests that IP{V6}_RECVORIGDSTADDR gets the post-NAT
+// address on the OUTPUT chain.
+type NATOutRECVORIGDSTADDR struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (NATOutRECVORIGDSTADDR) Name() string {
+ return "NATOutRECVORIGDSTADDR"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRECVORIGDSTADDR) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := natTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ return err
+ }
+
+ sendCh := make(chan error)
+ go func() {
+ // Packets will be sent to a non-container IP and redirected
+ // back to the container.
+ sendCh <- sendUDPLoop(ctx, ip, acceptPort)
+ }()
+
+ expectedIP := &net.IP{127, 0, 0, 1}
+ if ipv6 {
+ expectedIP = &net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+ }
+ if err := recvWithRECVORIGDSTADDR(ctx, ipv6, expectedIP, redirectPort); err != nil {
+ return err
+ }
+
+ select {
+ case err := <-sendCh:
+ return err
+ default:
+ return nil
+ }
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRECVORIGDSTADDR) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+func recvWithRECVORIGDSTADDR(ctx context.Context, ipv6 bool, expectedDst *net.IP, port uint16) error {
+ // The net package doesn't give guaranteed access to a connection's
+ // underlying FD, and thus we cannot call getsockopt. We have to use
+ // traditional syscalls for IP_RECVORIGDSTADDR.
+
+ // Create the listening socket.
+ var (
+ family = syscall.AF_INET
+ level = syscall.SOL_IP
+ option = syscall.IP_RECVORIGDSTADDR
+ bindAddr syscall.Sockaddr = &syscall.SockaddrInet4{
+ Port: int(port),
+ Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY
+ }
+ )
+ if ipv6 {
+ family = syscall.AF_INET6
+ level = syscall.SOL_IPV6
+ option = 74 // IPV6_RECVORIGDSTADDR, which is missing from the syscall package.
+ bindAddr = &syscall.SockaddrInet6{
+ Port: int(port),
+ Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any
+ }
+ }
+ sockfd, err := syscall.Socket(family, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ return fmt.Errorf("failed Socket(%d, %d, 0): %w", family, syscall.SOCK_DGRAM, err)
+ }
+ defer syscall.Close(sockfd)
+
+ if err := syscall.Bind(sockfd, bindAddr); err != nil {
+ return fmt.Errorf("failed Bind(%d, %+v): %v", sockfd, bindAddr, err)
+ }
+
+ // Enable IP_RECVORIGDSTADDR.
+ if err := syscall.SetsockoptInt(sockfd, level, option, 1); err != nil {
+ return fmt.Errorf("failed SetsockoptByte(%d, %d, %d, 1): %v", sockfd, level, option, err)
+ }
+
+ addrCh := make(chan interface{})
+ errCh := make(chan error)
+ go func() {
+ var addr interface{}
+ var err error
+ if ipv6 {
+ addr, err = recvOrigDstAddr6(sockfd)
+ } else {
+ addr, err = recvOrigDstAddr4(sockfd)
+ }
+ if err != nil {
+ errCh <- err
+ } else {
+ addrCh <- addr
+ }
+ }()
+
+ // Wait to receive a packet.
+ var addr interface{}
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case err := <-errCh:
+ return err
+ case addr = <-addrCh:
+ }
+
+ // Get a list of local IPs to verify that the packet now appears to have
+ // been sent to us.
+ var localAddrs []net.IP
+ if expectedDst != nil {
+ localAddrs = []net.IP{*expectedDst}
+ } else {
+ localAddrs, err = getInterfaceAddrs(ipv6)
+ if err != nil {
+ return fmt.Errorf("failed to get local interfaces: %w", err)
+ }
+ }
+
+ // Verify that the address has the post-NAT port and address.
+ if ipv6 {
+ return addrMatches6(addr.(syscall.RawSockaddrInet6), localAddrs, redirectPort)
+ }
+ return addrMatches4(addr.(syscall.RawSockaddrInet4), localAddrs, redirectPort)
+}
+
+func recvOrigDstAddr4(sockfd int) (syscall.RawSockaddrInet4, error) {
+ buf, err := recvOrigDstAddr(sockfd, syscall.SOL_IP, syscall.SizeofSockaddrInet4)
+ if err != nil {
+ return syscall.RawSockaddrInet4{}, err
+ }
+ var addr syscall.RawSockaddrInet4
+ binary.Unmarshal(buf, usermem.ByteOrder, &addr)
+ return addr, nil
+}
+
+func recvOrigDstAddr6(sockfd int) (syscall.RawSockaddrInet6, error) {
+ buf, err := recvOrigDstAddr(sockfd, syscall.SOL_IP, syscall.SizeofSockaddrInet6)
+ if err != nil {
+ return syscall.RawSockaddrInet6{}, err
+ }
+ var addr syscall.RawSockaddrInet6
+ binary.Unmarshal(buf, usermem.ByteOrder, &addr)
+ return addr, nil
+}
+
+func recvOrigDstAddr(sockfd int, level uintptr, addrSize int) ([]byte, error) {
+ buf := make([]byte, 64)
+ oob := make([]byte, syscall.CmsgSpace(addrSize))
+ for {
+ _, oobn, _, _, err := syscall.Recvmsg(
+ sockfd,
+ buf, // Message buffer.
+ oob, // Out-of-band buffer.
+ 0) // Flags.
+ if errors.Is(err, syscall.EINTR) {
+ continue
+ }
+ if err != nil {
+ return nil, fmt.Errorf("failed when calling Recvmsg: %w", err)
+ }
+ oob = oob[:oobn]
+
+ // Parse out the control message.
+ msgs, err := syscall.ParseSocketControlMessage(oob)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse control message: %w", err)
+ }
+ return msgs[0].Data, nil
+ }
+}
+
+func addrMatches4(got syscall.RawSockaddrInet4, wantAddrs []net.IP, port uint16) error {
+ for _, wantAddr := range wantAddrs {
+ want := syscall.RawSockaddrInet4{
+ Family: syscall.AF_INET,
+ Port: htons(port),
+ }
+ copy(want.Addr[:], wantAddr.To4())
+ if got == want {
+ return nil
+ }
+ }
+ return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs)
+}
+
+func addrMatches6(got syscall.RawSockaddrInet6, wantAddrs []net.IP, port uint16) error {
+ for _, wantAddr := range wantAddrs {
+ want := syscall.RawSockaddrInet6{
+ Family: syscall.AF_INET6,
+ Port: htons(port),
+ }
+ copy(want.Addr[:], wantAddr.To16())
+ if got == want {
+ return nil
+ }
+ }
+ return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs)
+}