From 833ba3590b422d453012e5b2ec2e780211d9caf9 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Tue, 19 Jan 2021 12:10:01 -0800 Subject: Ensure that IP{V6}_RECVORIGDSTADDR yields the post-NAT address and port. PiperOrigin-RevId: 352624174 --- test/iptables/BUILD | 2 + test/iptables/iptables_test.go | 8 ++ test/iptables/nat.go | 265 +++++++++++++++++++++++++++++++++++++---- 3 files changed, 249 insertions(+), 26 deletions(-) diff --git a/test/iptables/BUILD b/test/iptables/BUILD index 66453772a..ae4bba847 100644 --- a/test/iptables/BUILD +++ b/test/iptables/BUILD @@ -15,7 +15,9 @@ go_library( ], visibility = ["//test/iptables:__subpackages__"], deps = [ + "//pkg/binary", "//pkg/test/testutil", + "//pkg/usermem", ], ) diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 4733146c0..9a4f60a9a 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -424,3 +424,11 @@ func TestNATPreOriginalDst(t *testing.T) { func TestNATOutOriginalDst(t *testing.T) { singleTest(t, NATOutOriginalDst{}) } + +func TestNATPreRECVORIGDSTADDR(t *testing.T) { + singleTest(t, NATPreRECVORIGDSTADDR{}) +} + +func TestNATOutRECVORIGDSTADDR(t *testing.T) { + singleTest(t, NATOutRECVORIGDSTADDR{}) +} diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 495241482..c3874240f 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -20,6 +20,9 @@ import ( "fmt" "net" "syscall" + + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/usermem" ) const redirectPort = 42 @@ -43,6 +46,8 @@ func init() { RegisterTestCase(NATLoopbackSkipsPrerouting{}) RegisterTestCase(NATPreOriginalDst{}) RegisterTestCase(NATOutOriginalDst{}) + RegisterTestCase(NATPreRECVORIGDSTADDR{}) + RegisterTestCase(NATOutRECVORIGDSTADDR{}) } // NATPreRedirectUDPPort tests that packets are redirected to different port. @@ -538,9 +543,9 @@ func (NATOutOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) } func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.IP) error { - // The net package doesn't give guarantee access to the connection's + // The net package doesn't give guaranteed access to the connection's // underlying FD, and thus we cannot call getsockopt. We have to use - // traditional syscalls for SO_ORIGINAL_DST. + // traditional syscalls. // Create the listening socket, bind, listen, and accept. family := syscall.AF_INET @@ -609,36 +614,14 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net. if err != nil { return err } - // The original destination could be any of our IPs. - for _, dst := range originalDsts { - want := syscall.RawSockaddrInet6{ - Family: syscall.AF_INET6, - Port: htons(dropPort), - } - copy(want.Addr[:], dst.To16()) - if got == want { - return nil - } - } - return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + return addrMatches6(got, originalDsts, dropPort) } got, err := originalDestination4(connFD) if err != nil { return err } - // The original destination could be any of our IPs. - for _, dst := range originalDsts { - want := syscall.RawSockaddrInet4{ - Family: syscall.AF_INET, - Port: htons(dropPort), - } - copy(want.Addr[:], dst.To4()) - if got == want { - return nil - } - } - return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + return addrMatches4(got, originalDsts, dropPort) } // loopbackTests runs an iptables rule and ensures that packets sent to @@ -662,3 +645,233 @@ func loopbackTest(ctx context.Context, ipv6 bool, dest net.IP, args ...string) e return err } } + +// NATPreRECVORIGDSTADDR tests that IP{V6}_RECVORIGDSTADDR gets the post-NAT +// address on the PREROUTING chain. +type NATPreRECVORIGDSTADDR struct{ containerCase } + +// Name implements TestCase.Name. +func (NATPreRECVORIGDSTADDR) Name() string { + return "NATPreRECVORIGDSTADDR" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRECVORIGDSTADDR) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { + return err + } + + if err := recvWithRECVORIGDSTADDR(ctx, ipv6, nil, redirectPort); err != nil { + return err + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRECVORIGDSTADDR) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// NATOutRECVORIGDSTADDR tests that IP{V6}_RECVORIGDSTADDR gets the post-NAT +// address on the OUTPUT chain. +type NATOutRECVORIGDSTADDR struct{ containerCase } + +// Name implements TestCase.Name. +func (NATOutRECVORIGDSTADDR) Name() string { + return "NATOutRECVORIGDSTADDR" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRECVORIGDSTADDR) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { + return err + } + + sendCh := make(chan error) + go func() { + // Packets will be sent to a non-container IP and redirected + // back to the container. + sendCh <- sendUDPLoop(ctx, ip, acceptPort) + }() + + expectedIP := &net.IP{127, 0, 0, 1} + if ipv6 { + expectedIP = &net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + } + if err := recvWithRECVORIGDSTADDR(ctx, ipv6, expectedIP, redirectPort); err != nil { + return err + } + + select { + case err := <-sendCh: + return err + default: + return nil + } +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRECVORIGDSTADDR) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +func recvWithRECVORIGDSTADDR(ctx context.Context, ipv6 bool, expectedDst *net.IP, port uint16) error { + // The net package doesn't give guaranteed access to a connection's + // underlying FD, and thus we cannot call getsockopt. We have to use + // traditional syscalls for IP_RECVORIGDSTADDR. + + // Create the listening socket. + var ( + family = syscall.AF_INET + level = syscall.SOL_IP + option = syscall.IP_RECVORIGDSTADDR + bindAddr syscall.Sockaddr = &syscall.SockaddrInet4{ + Port: int(port), + Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY + } + ) + if ipv6 { + family = syscall.AF_INET6 + level = syscall.SOL_IPV6 + option = 74 // IPV6_RECVORIGDSTADDR, which is missing from the syscall package. + bindAddr = &syscall.SockaddrInet6{ + Port: int(port), + Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any + } + } + sockfd, err := syscall.Socket(family, syscall.SOCK_DGRAM, 0) + if err != nil { + return fmt.Errorf("failed Socket(%d, %d, 0): %w", family, syscall.SOCK_DGRAM, err) + } + defer syscall.Close(sockfd) + + if err := syscall.Bind(sockfd, bindAddr); err != nil { + return fmt.Errorf("failed Bind(%d, %+v): %v", sockfd, bindAddr, err) + } + + // Enable IP_RECVORIGDSTADDR. + if err := syscall.SetsockoptInt(sockfd, level, option, 1); err != nil { + return fmt.Errorf("failed SetsockoptByte(%d, %d, %d, 1): %v", sockfd, level, option, err) + } + + addrCh := make(chan interface{}) + errCh := make(chan error) + go func() { + var addr interface{} + var err error + if ipv6 { + addr, err = recvOrigDstAddr6(sockfd) + } else { + addr, err = recvOrigDstAddr4(sockfd) + } + if err != nil { + errCh <- err + } else { + addrCh <- addr + } + }() + + // Wait to receive a packet. + var addr interface{} + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err + case addr = <-addrCh: + } + + // Get a list of local IPs to verify that the packet now appears to have + // been sent to us. + var localAddrs []net.IP + if expectedDst != nil { + localAddrs = []net.IP{*expectedDst} + } else { + localAddrs, err = getInterfaceAddrs(ipv6) + if err != nil { + return fmt.Errorf("failed to get local interfaces: %w", err) + } + } + + // Verify that the address has the post-NAT port and address. + if ipv6 { + return addrMatches6(addr.(syscall.RawSockaddrInet6), localAddrs, redirectPort) + } + return addrMatches4(addr.(syscall.RawSockaddrInet4), localAddrs, redirectPort) +} + +func recvOrigDstAddr4(sockfd int) (syscall.RawSockaddrInet4, error) { + buf, err := recvOrigDstAddr(sockfd, syscall.SOL_IP, syscall.SizeofSockaddrInet4) + if err != nil { + return syscall.RawSockaddrInet4{}, err + } + var addr syscall.RawSockaddrInet4 + binary.Unmarshal(buf, usermem.ByteOrder, &addr) + return addr, nil +} + +func recvOrigDstAddr6(sockfd int) (syscall.RawSockaddrInet6, error) { + buf, err := recvOrigDstAddr(sockfd, syscall.SOL_IP, syscall.SizeofSockaddrInet6) + if err != nil { + return syscall.RawSockaddrInet6{}, err + } + var addr syscall.RawSockaddrInet6 + binary.Unmarshal(buf, usermem.ByteOrder, &addr) + return addr, nil +} + +func recvOrigDstAddr(sockfd int, level uintptr, addrSize int) ([]byte, error) { + buf := make([]byte, 64) + oob := make([]byte, syscall.CmsgSpace(addrSize)) + for { + _, oobn, _, _, err := syscall.Recvmsg( + sockfd, + buf, // Message buffer. + oob, // Out-of-band buffer. + 0) // Flags. + if errors.Is(err, syscall.EINTR) { + continue + } + if err != nil { + return nil, fmt.Errorf("failed when calling Recvmsg: %w", err) + } + oob = oob[:oobn] + + // Parse out the control message. + msgs, err := syscall.ParseSocketControlMessage(oob) + if err != nil { + return nil, fmt.Errorf("failed to parse control message: %w", err) + } + return msgs[0].Data, nil + } +} + +func addrMatches4(got syscall.RawSockaddrInet4, wantAddrs []net.IP, port uint16) error { + for _, wantAddr := range wantAddrs { + want := syscall.RawSockaddrInet4{ + Family: syscall.AF_INET, + Port: htons(port), + } + copy(want.Addr[:], wantAddr.To4()) + if got == want { + return nil + } + } + return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs) +} + +func addrMatches6(got syscall.RawSockaddrInet6, wantAddrs []net.IP, port uint16) error { + for _, wantAddr := range wantAddrs { + want := syscall.RawSockaddrInet6{ + Family: syscall.AF_INET6, + Port: htons(port), + } + copy(want.Addr[:], wantAddr.To16()) + if got == want { + return nil + } + } + return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs) +} -- cgit v1.2.3